liger-kernel 0.5.10__py3-none-any.whl → 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/ops/dyt.py +0 -2
- liger_kernel/ops/fused_add_rms_norm.py +412 -0
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +1 -1
- liger_kernel/ops/layer_norm.py +126 -89
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/rms_norm.py +267 -56
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +62 -50
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/transformers/__init__.py +8 -0
- liger_kernel/transformers/functional.py +67 -0
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/model/gemma.py +25 -8
- liger_kernel/transformers/model/gemma2.py +27 -8
- liger_kernel/transformers/model/gemma3.py +63 -99
- liger_kernel/transformers/model/glm4.py +16 -7
- liger_kernel/transformers/model/llama.py +25 -7
- liger_kernel/transformers/model/llama4.py +108 -0
- liger_kernel/transformers/model/llava.py +95 -124
- liger_kernel/transformers/model/mistral.py +13 -8
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +16 -7
- liger_kernel/transformers/model/olmo2.py +16 -7
- liger_kernel/transformers/model/paligemma.py +8 -1
- liger_kernel/transformers/model/phi3.py +25 -8
- liger_kernel/transformers/model/qwen2.py +24 -7
- liger_kernel/transformers/model/qwen2_5_vl.py +41 -91
- liger_kernel/transformers/model/qwen2_vl.py +38 -100
- liger_kernel/transformers/model/qwen3.py +11 -3
- liger_kernel/transformers/model/qwen3_moe.py +10 -6
- liger_kernel/transformers/model/smollm3.py +189 -0
- liger_kernel/transformers/monkey_patch.py +389 -82
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/rms_norm.py +40 -4
- liger_kernel/transformers/softmax.py +12 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/METADATA +18 -14
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/RECORD +47 -37
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/WHEEL +1 -1
- liger_kernel/transformers/gema3_rms.py +0 -8
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/top_level.txt +0 -0
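
The detailed diff below covers liger_kernel/transformers/monkey_patch.py, the largest change in this release (+389/-82), shown as unified-diff hunks.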
@@ -2,6 +2,7 @@ import inspect
 import logging
 
 from functools import partial
+from types import MethodType
 from typing import Callable
 
 import transformers
@@ -28,6 +29,7 @@ from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
 from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated
 from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
 from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
+from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
 from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
@@ -54,7 +56,7 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
     module.__dict__[method_name] = new_method.__get__(module, module.__class__)
 
 
-def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
+def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True, row_mode=None):
     # Check if the module is a PEFT ModulesToSaveWrapper
     # If it is, we need to patch the modules_to_save.default and original_modules
     if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
@@ -64,26 +66,29 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
             getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
         )
         module.modules_to_save.default.in_place = in_place
+        module.modules_to_save.default.row_mode = row_mode
         module.original_module.offset = offset
         module.original_module.casting_mode = casting_mode
         module.original_module.variance_epsilon = (
             getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
         )
         module.original_module.in_place = in_place
+        module.original_module.row_mode = row_mode
         _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
         _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
         _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
         _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
-        module.modules_to_save.default
-        module.original_module
+        _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerRMSNorm.__name__)
+        _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerRMSNorm.__name__)
     else:
         module.offset = offset
         module.casting_mode = casting_mode
         module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
         module.in_place = in_place
+        module.row_mode = row_mode
         _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
         _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
-        module
+        _bind_method_to_module(module, "_get_name", lambda self: LigerRMSNorm.__name__)
 
 
 def _patch_layer_norm_module(module, eps=1e-6):
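
A note on the _bind_method_to_module pattern these hunks build on: plain Python functions are descriptors, so fn.__get__(obj, type(obj)) yields a method bound to that one instance, and writing it into module.__dict__ bypasses nn.Module.__setattr__ bookkeeping. Binding _get_name this way (instead of the removed attribute tweaks) makes torch's repr report the Liger class name, since nn.Module.__repr__ calls self._get_name(). A minimal runnable sketch; the toy LayerNorm here is illustrative only:

import torch.nn as nn

def _bind_method_to_module(module, method_name, new_method):
    # functions are descriptors: __get__ returns a method bound to this instance;
    # writing into __dict__ sidesteps nn.Module.__setattr__
    module.__dict__[method_name] = new_method.__get__(module, module.__class__)

norm = nn.LayerNorm(8)
_bind_method_to_module(norm, "_get_name", lambda self: "LigerLayerNorm")
print(norm)  # the instance now prints as LigerLayerNorm(...), its class is unchanged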
@@ -105,28 +110,28 @@ def _patch_layer_norm_module(module, eps=1e-6):
         module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
             module, "normalized_shape", None
         )
-        _bind_method_to_module(module.modules_to_save.default, "forward",
-        _bind_method_to_module(module.modules_to_save.default, "extra_repr",
-        _bind_method_to_module(module.original_module, "forward",
-        _bind_method_to_module(module.original_module, "extra_repr",
-        module.modules_to_save.default
-        module.original_module
+        _bind_method_to_module(module.modules_to_save.default, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerLayerNorm.extra_repr)
+        _bind_method_to_module(module.original_module, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module.original_module, "extra_repr", LigerLayerNorm.extra_repr)
+        _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerLayerNorm.__name__)
+        _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerLayerNorm.__name__)
     else:
         module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
         module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
         _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
         _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
-        module
+        _bind_method_to_module(module, "_get_name", lambda self: LigerLayerNorm.__name__)
 
 
 def _patch_swiglu_module(module, liger_module):
     _bind_method_to_module(module, "forward", liger_module.forward)
-    module
+    _bind_method_to_module(module, "_get_name", lambda self: liger_module.__name__)
 
 
 def _patch_geglu_module(module):
     _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
-    module
+    _bind_method_to_module(module, "_get_name", lambda self: LigerGEGLUMLP.__name__)
 
 
 def apply_liger_kernel_to_granite(
@@ -257,10 +262,16 @@ def apply_liger_kernel_to_llama(
 
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
+            if model is not None:
+                model.forward = MethodType(llama_lce_forward, model)
+            else:
+                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(llama_lce_forward_deprecated, model)
+            else:
+                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
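
The recurring change in this and the following hunks: when a model instance is passed, forward is now patched on the instance with types.MethodType rather than on the class. MethodType pre-binds the instance, so the patched forward still receives self while other models of the same class stay untouched. A self-contained sketch of the difference, using stand-in classes rather than the HF model:

from types import MethodType

class CausalLM:
    def forward(self, x):
        return "original"

def lce_forward(self, x):
    return "liger"

m = CausalLM()
m.forward = MethodType(lce_forward, m)       # instance-level patch (model is not None)
assert m.forward(0) == "liger"
assert CausalLM().forward(0) == "original"   # the class default is untouched

CausalLM.forward = lce_forward               # class-level patch (model is None branch)
assert CausalLM().forward(0) == "liger"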
@@ -280,6 +291,77 @@ def apply_liger_kernel_to_llama(
             _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
+def apply_liger_kernel_to_smollm3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace SmolLM3 model
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.smollm3 import modeling_smollm3
+    from transformers.models.smollm3.modeling_smollm3 import SmolLM3Model
+
+    if rope:
+        modeling_smollm3.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_smollm3.SmolLM3RMSNorm = LigerRMSNorm
+    if swiglu:
+        modeling_smollm3.SmolLM3MLP = LigerSwiGLUMLP
+
+    if cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_smollm3.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(smollm3_lce_forward, model)
+        else:
+            modeling_smollm3.SmolLM3ForCausalLM.forward = smollm3_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules (e.g. SmolLM3RMSNorm or SmolLM3MLP)
+
+        # get the base model from the model instance
+        base_model: SmolLM3Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_llava(
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
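
A hypothetical call for the new SmolLM3 entry point, following the signature above. The import path assumes the helper is re-exported from liger_kernel.transformers like the other apply_liger_kernel_to_* functions (otherwise import it from liger_kernel.transformers.monkey_patch), and the checkpoint name is only a placeholder:

from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_smollm3

model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM3-3B")  # placeholder checkpoint
apply_liger_kernel_to_smollm3(
    rope=True,
    rms_norm=True,
    swiglu=True,
    fused_linear_cross_entropy=True,  # cross_entropy must then stay False
    model=model,  # also patch the already-instantiated submodules in place
)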
@@ -314,13 +396,20 @@ def apply_liger_kernel_to_llava(
         logger.warning(TRANSFORMER_DEPRECATION_WARNING)
         modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        if transformer_version >= version.parse("4.
-            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+        if transformer_version >= version.parse("4.52.0"):
+            if model is not None:
+                model.forward = MethodType(llava_lce_forward, model)
+            else:
+                modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+        elif transformer_version >= version.parse("4.49.0") and transformer_version < version.parse("4.52.0"):
+            if model is not None:
+                model.forward = MethodType(llava_lce_forward_deprecated, model)
+            else:
+                modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
         else:  # if version < 4.49.0
             logger.warning(
-                "
+                "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
             )
-            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
 
     if model is not None:
         text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
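
The version gates above compare transformer_version, parsed once from transformers.__version__ near the top of the module, against packaging.version objects. A sketch of the pattern in isolation:

import transformers
from packaging import version

transformer_version = version.parse(transformers.__version__)
if transformer_version >= version.parse("4.52.0"):
    pass  # patch with the current forward
elif transformer_version >= version.parse("4.49.0"):
    pass  # fall back to the forward written for the older transformers API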
@@ -359,6 +448,92 @@ def apply_liger_kernel_to_llava(
         logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
 
 
+def apply_liger_kernel_to_llama4(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+    layer_norm: bool = True,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Llama4 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.llama4 import modeling_llama4
+    from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM
+    from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration
+    from transformers.models.llama4.modeling_llama4 import Llama4TextModel
+    from transformers.models.llama4.modeling_llama4 import Llama4VisionModel
+
+    from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward
+
+    if rope:
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Llama4 models.")
+    if rms_norm:
+        modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
+    if swiglu:
+        modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
+
+    if cross_entropy:
+        modeling_llama4.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        modeling_llama4.Llama4ForCausalLM.forward = llama4_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, Llama4ForConditionalGeneration):
+            language_model: Llama4ForCausalLM = model.language_model
+            vision_model: Llama4VisionModel = model.vision_model
+            text_model: Llama4TextModel = language_model.model
+        elif isinstance(model, Llama4ForCausalLM):
+            text_model = model.model
+            vision_model = None
+        elif isinstance(model, Llama4TextModel):
+            text_model = model
+            vision_model = None
+
+        else:
+            raise ValueError(f"Unsupported Llama4 model type: {type(model)}")
+
+        if text_model:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+        if vision_model:
+            _patch_layer_norm_module(vision_model.layernorm_pre)
+            _patch_layer_norm_module(vision_model.layernorm_post)
+
+            for layer in vision_model.model.layers:
+                if layer_norm:
+                    _patch_layer_norm_module(layer.input_layernorm)
+                    _patch_layer_norm_module(layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_mllama(
     rope: bool = True,
     cross_entropy: bool = False,
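
A hypothetical invocation of the new Llama4 entry point; note that rope defaults to False here and requesting it raises NotImplementedError, per the hunk above:

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4

# patch at the class level (model=None), before the model is instantiated
apply_liger_kernel_to_llama4(rms_norm=True, swiglu=True, fused_linear_cross_entropy=True)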
@@ -400,7 +575,7 @@ def apply_liger_kernel_to_mllama(
 
     if rope:
         modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_mllama.nn.LayerNorm = LigerLayerNorm
     if rms_norm:
         modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
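
The added `model is None` guard here, and in the gemma3, paligemma, and qwen2-vl hunks below, keeps the module-global nn.LayerNorm swap from firing when a concrete model instance is being patched; in that case the instance's layer norms are patched in place via _patch_layer_norm_module instead.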
@@ -416,10 +591,16 @@ def apply_liger_kernel_to_mllama(
         modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
+            if model is not None:
+                model.forward = MethodType(mllama_lce_forward, model)
+            else:
+                modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(mllama_lce_forward_deprecated, model)
+            else:
+                modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -428,13 +609,17 @@ def apply_liger_kernel_to_mllama(
         if isinstance(model, MllamaForConditionalGeneration):
             language_model: MllamaForCausalLM = model.language_model
             vision_model: MllamaVisionModel = model.vision_model
-            text_model: MllamaTextModel = language_model.model
+            if isinstance(language_model, MllamaForCausalLM):
+                text_model: MllamaTextModel = language_model.model
+            else:
+                text_model = language_model
         elif isinstance(model, MllamaForCausalLM):
             text_model = model.model
             vision_model = None
         elif isinstance(model, MllamaTextModel):
             text_model = model
             vision_model = None
+
         else:
             raise ValueError(f"Unsupported Mllama model type: {type(model)}")
 
@@ -501,7 +686,17 @@ def apply_liger_kernel_to_mistral(
     if cross_entropy:
         modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+        if transformer_version >= version.parse("4.49.0"):
+            if model is not None:
+                model.forward = MethodType(mistral_lce_forward, model)
+            else:
+                modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+        else:
+            logger.warning(
+                "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
+            )
+            logger.warning("LigerFusedLinearCrossEntropy patch is not applied.")
+
     if swiglu:
         modeling_mistral.MistralMLP = LigerSwiGLUMLP
 
@@ -569,10 +764,16 @@ def apply_liger_kernel_to_mixtral(
 
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
+            if model is not None:
+                model.forward = MethodType(mixtral_lce_forward, model)
+            else:
+                modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(mixtral_lce_forward_deprecated, model)
+            else:
+                modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
     if swiglu:
         modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP
 
@@ -626,8 +827,8 @@ def apply_liger_kernel_to_gemma(
     from transformers.models.gemma import modeling_gemma
     from transformers.models.gemma.modeling_gemma import GemmaModel
 
-
-
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma
+
     _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)
 
     if rope:
@@ -646,10 +847,16 @@ def apply_liger_kernel_to_gemma(
         modeling_gemma.GemmaMLP = LigerGEGLUMLP
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+            if model is not None:
+                model.forward = MethodType(gemma_lce_forward, model)
+            else:
+                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(gemma_lce_forward_deprecated, model)
+            else:
+                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -700,7 +907,8 @@ def apply_liger_kernel_to_gemma2(
     from transformers.models.gemma2 import modeling_gemma2
     from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
 
-
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma2
+
     _patch_rms_norm_module_for_gemma2 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
     )
@@ -720,10 +928,16 @@ def apply_liger_kernel_to_gemma2(
         modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
+            if model is not None:
+                model.forward = MethodType(gemma2_lce_forward, model)
+            else:
+                modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
         else:
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
+            if model is not None:
+                model.forward = MethodType(gemma2_lce_forward_deprected, model)
+            else:
+                modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
     if geglu:
         modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
 
@@ -777,9 +991,10 @@ def apply_liger_kernel_to_gemma3_text(
     from transformers.models.gemma3 import modeling_gemma3
     from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
     from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel
 
-    from liger_kernel.transformers.gema3_rms import LigerRMSNormForGemma3
     from liger_kernel.transformers.model.gemma3 import causal_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma3
 
     _patch_rms_norm_module_for_gemma3 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
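
This import swap pairs with the removal of liger_kernel/transformers/gema3_rms.py in the file list above: the model-specific RMSNorm variants (LigerRMSNormForGemma, LigerRMSNormForGemma2, LigerRMSNormForGemma3, and the new LigerRMSNormForOlmo2 and LigerRMSNormForGlm4 used further down) are consolidated in liger_kernel/transformers/rms_norm.py (+40/-4).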
@@ -801,15 +1016,18 @@ def apply_liger_kernel_to_gemma3_text(
             nn.functional.cross_entropy = liger_cross_entropy
 
     if fused_linear_cross_entropy:
-        modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
+        if model is not None:
+            model.forward = MethodType(causal_forward, model)
+        else:
+            modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-        if isinstance(model, Gemma3ForCausalLM):
+        if isinstance(model, Gemma3ForCausalLM) or isinstance(model, Gemma3TextModel):
             # get the base model from the model instance
-            base_model = model.model
+            base_model = model.model if isinstance(model, Gemma3ForCausalLM) else model
 
             if rms_norm:
                 _patch_rms_norm_module_for_gemma3(base_model.norm)
@@ -871,7 +1089,7 @@ def apply_liger_kernel_to_gemma3(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
     )
 
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_siglip.nn.LayerNorm = LigerLayerNorm
 
     apply_liger_kernel_to_gemma3_text(
@@ -882,7 +1100,10 @@ def apply_liger_kernel_to_gemma3(
         modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss
 
     if fused_linear_cross_entropy:
-        modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
+        if model is not None:
+            model.forward = MethodType(multimodal_forward, model)
+        else:
+            modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -950,7 +1171,9 @@ def apply_liger_kernel_to_paligemma(
     # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']
 
     from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
+    from transformers.models.gemma.modeling_gemma import GemmaModel
     from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
+    from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
     from transformers.models.paligemma import modeling_paligemma
     from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
     from transformers.models.siglip import modeling_siglip
@@ -961,7 +1184,7 @@ def apply_liger_kernel_to_paligemma(
     from liger_kernel.transformers.model.paligemma import lce_forward_deprecated
 
     # The vision_tower is a SiglipVisionModel
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_siglip.nn.LayerNorm = LigerLayerNorm
 
     # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
@@ -979,10 +1202,16 @@ def apply_liger_kernel_to_paligemma(
         modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
+            if model is not None:
+                model.forward = MethodType(lce_forward, model)
+            else:
+                modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(lce_forward_deprecated, model)
+            else:
+                modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -1003,7 +1232,7 @@ def apply_liger_kernel_to_paligemma(
 
         language_model = model.language_model
 
-        if isinstance(language_model, GemmaForCausalLM):
+        if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
             apply_liger_kernel_to_gemma(
                 rope=rope,
                 cross_entropy=False,
@@ -1013,7 +1242,7 @@ def apply_liger_kernel_to_paligemma(
                 model=language_model,
             )
 
-        elif isinstance(language_model, Gemma2ForCausalLM):
+        elif isinstance(language_model, (Gemma2ForCausalLM, Gemma2Model)):
             apply_liger_kernel_to_gemma2(
                 rope=rope,
                 cross_entropy=False,
@@ -1074,10 +1303,16 @@ def apply_liger_kernel_to_qwen2(
 
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
+            if model is not None:
+                model.forward = MethodType(qwen2_lce_forward, model)
+            else:
+                modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(qwen2_lce_forward_deprecated, model)
+            else:
+                modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
 
     if swiglu:
         modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
@@ -1133,7 +1368,10 @@ def apply_liger_kernel_to_qwen3(
             nn.functional.cross_entropy = liger_cross_entropy
 
     if fused_linear_cross_entropy:
-        modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
+        if model is not None:
+            model.forward = MethodType(qwen3_lce_forward, model)
+        else:
+            modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
 
     if swiglu:
         modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
@@ -1188,7 +1426,10 @@ def apply_liger_kernel_to_qwen3_moe(
             nn.functional.cross_entropy = liger_cross_entropy
 
     if fused_linear_cross_entropy:
-        modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
+        if model is not None:
+            model.forward = MethodType(qwen3_lce_forward, model)
+        else:
+            modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
 
     if swiglu:
         modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
@@ -1204,7 +1445,8 @@ def apply_liger_kernel_to_qwen3_moe(
             _patch_rms_norm_module(base_model.norm)
         for decoder_layer in base_model.layers:
             if swiglu:
-                _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
+                for mlp_expert in decoder_layer.mlp.experts:
+                    _patch_swiglu_module(mlp_expert, LigerQwen3MoeSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
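
In the Qwen3-MoE fix above, decoder_layer.mlp is the sparse MoE block, whose experts sit in an nn.ModuleList; the new loop patches each expert's SwiGLU MLP individually rather than the block as a whole.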
@@ -1221,7 +1463,7 @@ def apply_liger_kernel_to_qwen2_vl(
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
-    NOTE: Qwen2-VL is not
+    NOTE: Qwen2-VL is not supported in transformers<4.52.4
 
     Args:
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
@@ -1235,12 +1477,19 @@ def apply_liger_kernel_to_qwen2_vl(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
             loaded. Default is None.
     """
+    if transformer_version < version.parse("4.52.4"):
+        logger.warning("Qwen2-VL support is only compatible with transformers >= 4.52.4")
+        return
+
     assert not (cross_entropy and fused_linear_cross_entropy), (
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )
 
     from transformers.models.qwen2_vl import modeling_qwen2_vl
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
     from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel
 
     from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
@@ -1249,12 +1498,15 @@ def apply_liger_kernel_to_qwen2_vl(
     if rms_norm:
         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
         modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_qwen2_vl.LayerNorm = LigerLayerNorm
     if cross_entropy:
         modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
+        if model is not None:
+            model.forward = MethodType(qwen2_vl_lce_forward, model)
+        else:
+            modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
     if swiglu:
         modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
 
@@ -1262,24 +1514,38 @@ def apply_liger_kernel_to_qwen2_vl(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-
-
+        if isinstance(model, (Qwen2VLForConditionalGeneration, Qwen2VLModel)):
+            # Note: language_model and visual properties can be accessed throught conditional class for BC.
+            # Not sure if it is subject to changes in the future.
+            # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
+            text_model: Qwen2VLTextModel = model.language_model
+            vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
+        elif isinstance(model, Qwen2VLTextModel):
+            text_model: Qwen2VLTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
+            )
 
-
-
-        for vision_block in
+        # Patch Qwen2VisionTransformerPretrainedModel
+        if vision_model is not None:
+            for vision_block in vision_model.blocks:
                 if layer_norm:
                     _patch_layer_norm_module(vision_block.norm1)
                     _patch_layer_norm_module(vision_block.norm2)
 
-
-
-        for decoder_layer in base_model.layers:
-            if swiglu:
-                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+        # Patch Qwen2VisionTextModel
+        if text_model is not None:
             if rms_norm:
-                _patch_rms_norm_module(
-
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
 def apply_liger_kernel_to_qwen2_5_vl(
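
With this restructuring, instance patching for Qwen2-VL (and Qwen2.5-VL below) dispatches on the model type: the conditional-generation and bare-Model classes expose language_model and visual properties, the text model is handled directly, and any other type, including a vision tower on its own, now raises TypeError.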
@@ -1305,12 +1571,19 @@ def apply_liger_kernel_to_qwen2_5_vl(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
             loaded. Default is None.
     """
+    if transformer_version < version.parse("4.52.4"):
+        logger.warning("Qwen2.5-VL support is only compatible with transformers >= 4.52.4")
+        return
+
     assert not (cross_entropy and fused_linear_cross_entropy), (
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )
 
     from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
     from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel
 
     from liger_kernel.transformers.model.qwen2_5_vl import lce_forward as qwen2_5_vl_lce_forward
@@ -1321,7 +1594,10 @@ def apply_liger_kernel_to_qwen2_5_vl(
     if cross_entropy:
         modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
+        if model is not None:
+            model.forward = MethodType(qwen2_5_vl_lce_forward, model)
+        else:
+            modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
     if swiglu:
         modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
 
@@ -1329,24 +1605,37 @@ def apply_liger_kernel_to_qwen2_5_vl(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-
-
+        if isinstance(model, (Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel)):
+            # Note: language_model and visual properties can be accessed throught conditional class for BC.
+            # Not sure if it is subject to changes in the future.
+            # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
+            text_model: Qwen2_5_VLTextModel = model.language_model
+            vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
+        elif isinstance(model, Qwen2_5_VLTextModel):
+            text_model: Qwen2_5_VLTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
+            )
 
-        if
+        if vision_model is not None:
             # Patch Qwen2_5_VisionTransformerPretrainedModel
             for vision_block in model.visual.blocks:
                 if rms_norm:
                     _patch_rms_norm_module(vision_block.norm1)
                     _patch_rms_norm_module(vision_block.norm2)
 
-        if
-            _patch_rms_norm_module(base_model.norm)
-        for decoder_layer in base_model.layers:
-            if swiglu:
-                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+        if text_model is not None:
             if rms_norm:
-                _patch_rms_norm_module(
-
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
 def apply_liger_kernel_to_phi3(
@@ -1395,10 +1684,16 @@ def apply_liger_kernel_to_phi3(
         modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
+            if model is not None:
+                model.forward = MethodType(phi3_lce_forward, model)
+            else:
+                modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(phi3_lce_forward_deprecated, model)
+            else:
+                modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -1449,11 +1744,12 @@ def apply_liger_kernel_to_olmo2(
     from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
 
     from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
 
     if rope:
         modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
     if rms_norm:
-        modeling_olmo2.Olmo2RMSNorm =
+        modeling_olmo2.Olmo2RMSNorm = LigerRMSNormForOlmo2
     if swiglu:
         modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
     if cross_entropy:
@@ -1461,7 +1757,10 @@ def apply_liger_kernel_to_olmo2(
 
             nn.functional.cross_entropy = liger_cross_entropy
     if fused_linear_cross_entropy:
-        modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
+        if model is not None:
+            model.forward = MethodType(olmo2_lce_forward, model)
+        else:
+            modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -1512,11 +1811,12 @@ def apply_liger_kernel_to_glm4(
     from transformers.models.glm4.modeling_glm4 import Glm4Model
 
     from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
 
     if rope:
         raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
     if rms_norm:
-        modeling_glm4.Glm4RMSNorm =
+        modeling_glm4.Glm4RMSNorm = LigerRMSNormForGlm4
     if swiglu:
         modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
     if cross_entropy:
@@ -1524,7 +1824,10 @@ def apply_liger_kernel_to_glm4(
 
             nn.functional.cross_entropy = liger_cross_entropy
     if fused_linear_cross_entropy:
-        modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
+        if model is not None:
+            model.forward = MethodType(glm4_lce_forward, model)
+        else:
+            modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -1554,6 +1857,8 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma3": apply_liger_kernel_to_gemma3,
    "glm4": apply_liger_kernel_to_glm4,
     "llama": apply_liger_kernel_to_llama,
+    "llama4_text": apply_liger_kernel_to_llama4,
+    "llama4": apply_liger_kernel_to_llama4,
     "llava": apply_liger_kernel_to_llava,
     "granite": apply_liger_kernel_to_granite,
     "mllama": apply_liger_kernel_to_mllama,
@@ -1565,7 +1870,10 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "qwen3": apply_liger_kernel_to_qwen3,
     "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
     "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
+    "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
     "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
+    "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
+    "smollm3": apply_liger_kernel_to_smollm3,
     "phi3": apply_liger_kernel_to_phi3,
     "paligemma": apply_liger_kernel_to_paligemma,
 }
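
For reference, a sketch of how this registry is consumed by _apply_liger_kernel_to_instance (final hunk below): the config's model_type picks the apply function, and inspect.signature filters out kwargs the function does not accept. Names follow the diff; the wrapper itself is illustrative:

import inspect

from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN

def apply_liger_to_instance(model, **kwargs):
    # model.config.model_type is the key (e.g. "smollm3", "llama4", "qwen2_5_vl_text")
    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model.config.model_type]
    accepted = inspect.signature(apply_fn).parameters
    apply_fn(model=model, **{k: v for k, v in kwargs.items() if k in accepted})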
@@ -1625,7 +1933,6 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
         return
 
     apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
-
     apply_fn_signature = inspect.signature(apply_fn)
 
     # Filter out the keyword arguments that are not supported by the apply function