liger-kernel 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/cosine_similarity_loss.py +20 -5
- liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
- liger_kernel/chunked_loss/grpo_loss.py +8 -5
- liger_kernel/chunked_loss/jsd_loss.py +39 -11
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +492 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +61 -0
- liger_kernel/ops/backends/_ascend/ops/embedding.py +214 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +191 -0
- liger_kernel/ops/backends/_ascend/ops/llama4_rope.py +298 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +275 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +265 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ops/tvd.py +223 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +367 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +71 -11
- liger_kernel/ops/dyt.py +5 -2
- liger_kernel/ops/fused_add_rms_norm.py +21 -23
- liger_kernel/ops/fused_linear_cross_entropy.py +32 -5
- liger_kernel/ops/geglu.py +5 -3
- liger_kernel/ops/group_norm.py +12 -8
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +89 -69
- liger_kernel/ops/poly_norm.py +19 -21
- liger_kernel/ops/rms_norm.py +149 -71
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +25 -0
- liger_kernel/transformers/__init__.py +25 -0
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +1 -1
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/functional.py +44 -26
- liger_kernel/transformers/fused_add_rms_norm.py +1 -1
- liger_kernel/transformers/fused_linear_cross_entropy.py +9 -4
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +57 -2
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +1 -1
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/falcon_h1.py +19 -5
- liger_kernel/transformers/model/gemma.py +17 -6
- liger_kernel/transformers/model/gemma2.py +17 -8
- liger_kernel/transformers/model/gemma3.py +35 -16
- liger_kernel/transformers/model/glm4.py +16 -4
- liger_kernel/transformers/model/glm4v.py +16 -4
- liger_kernel/transformers/model/glm4v_moe.py +23 -4
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +12 -5
- liger_kernel/transformers/model/llama.py +14 -5
- liger_kernel/transformers/model/llama4.py +16 -4
- liger_kernel/transformers/model/llava.py +12 -4
- liger_kernel/transformers/model/loss_utils.py +37 -3
- liger_kernel/transformers/model/mistral.py +15 -6
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +12 -4
- liger_kernel/transformers/model/olmo2.py +16 -4
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +23 -5
- liger_kernel/transformers/model/phi3.py +14 -7
- liger_kernel/transformers/model/qwen2.py +16 -3
- liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
- liger_kernel/transformers/model/qwen2_vl.py +16 -4
- liger_kernel/transformers/model/qwen3.py +20 -5
- liger_kernel/transformers/model/qwen3_moe.py +19 -5
- liger_kernel/transformers/model/qwen3_next.py +17 -5
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +15 -6
- liger_kernel/transformers/monkey_patch.py +584 -49
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +1 -1
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +8 -3
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +1 -1
- liger_kernel/transformers/sparsemax.py +1 -1
- liger_kernel/transformers/swiglu.py +18 -1
- liger_kernel/transformers/tiled_mlp.py +125 -0
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +54 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/METADATA +14 -4
- liger_kernel-0.6.5.dist-info/RECORD +134 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/WHEEL +1 -1
- liger_kernel-0.6.3.dist-info/RECORD +0 -111
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/top_level.txt +0 -0

liger_kernel/transformers/monkey_patch.py

@@ -20,6 +20,7 @@ from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
 from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
 from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
 from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
+from liger_kernel.transformers.model.gpt_oss import lce_forward as gpt_oss_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
 from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
@@ -34,6 +35,7 @@ from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
 from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
+from liger_kernel.transformers.rope import liger_rotary_pos_emb_vision
 from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
 from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
 from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
@@ -428,7 +430,7 @@ def apply_liger_kernel_to_llava(
                     f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
                     f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
                 )
-            text_kwargs["model"] = model.language_model
+            text_kwargs["model"] = model.model.language_model
             text_liger_fn(**text_kwargs)
         elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
             logger.warning(f"{text_model_name} is not supported by Liger kernel.")
@@ -443,7 +445,7 @@ def apply_liger_kernel_to_llava(
                     f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
                     f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
                 )
-            vision_kwargs["model"] = model.vision_tower
+            vision_kwargs["model"] = model.model.vision_tower
             vision_liger_fn(**vision_kwargs)
         elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
             logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
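
The recurring change in these llava hunks (and in the mllama, gemma3, and paligemma hunks that follow) is that the patching code now reaches the text and vision submodules through the wrapper's inner model attribute, i.e. model.model.language_model rather than model.language_model. A minimal hedged sketch of the attribute path the 0.6.5 code assumes, with an illustrative checkpoint name:

    # Hedged sketch, not part of the diff: the attribute layout assumed by 0.6.5.
    # The checkpoint name is illustrative; exact wrapper classes depend on the
    # installed transformers version.
    from transformers import LlavaForConditionalGeneration

    model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
    language_model = model.model.language_model  # 0.6.3 used model.language_model
    vision_tower = model.model.vision_tower      # 0.6.3 used model.vision_tower
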
@@ -613,8 +615,8 @@ def apply_liger_kernel_to_mllama(
         # instance variables that reference already-instantiated modules

         if isinstance(model, MllamaForConditionalGeneration):
-            language_model: MllamaForCausalLM = model.language_model
-            vision_model: MllamaVisionModel = model.vision_model
+            language_model: MllamaForCausalLM = model.model.language_model
+            vision_model: MllamaVisionModel = model.model.vision_model
             if isinstance(language_model, MllamaForCausalLM):
                 text_model: MllamaTextModel = language_model.model
             else:
@@ -1116,8 +1118,8 @@ def apply_liger_kernel_to_gemma3(
         # instance variables that reference already-instantiated modules

         if isinstance(model, Gemma3ForConditionalGeneration):
-            if isinstance(model.vision_tower, SiglipVisionModel):
-                vision_tower = model.vision_tower
+            if isinstance(model.model.vision_tower, SiglipVisionModel):
+                vision_tower = model.model.vision_tower

                 _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)

@@ -1130,7 +1132,7 @@ def apply_liger_kernel_to_gemma3(
                 raise TypeError("The vision tower must be SiglipVisionModel")

             if rms_norm:
-                _patch_rms_norm_module_for_gemma3(model.multi_modal_projector.mm_soft_emb_norm)
+                _patch_rms_norm_module_for_gemma3(model.model.multi_modal_projector.mm_soft_emb_norm)

             apply_liger_kernel_to_gemma3_text(
                 rope=rope,
@@ -1138,7 +1140,7 @@ def apply_liger_kernel_to_gemma3(
                 fused_linear_cross_entropy=False,
                 rms_norm=rms_norm,
                 geglu=geglu,
-                model=model.language_model,
+                model=model.model.language_model,
             )

         else:
@@ -1226,7 +1228,7 @@ def apply_liger_kernel_to_paligemma(
         if not isinstance(model, PaliGemmaForConditionalGeneration):
             raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")

-        vision_tower: SiglipVisionModel = model.vision_tower
+        vision_tower: SiglipVisionModel = model.model.vision_tower

         _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)

@@ -1236,7 +1238,7 @@ def apply_liger_kernel_to_paligemma(
             _patch_layer_norm_module(layer.layer_norm1)
             _patch_layer_norm_module(layer.layer_norm2)

-        language_model = model.language_model
+        language_model = model.model.language_model

         if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
             apply_liger_kernel_to_gemma(
@@ -1457,6 +1459,79 @@ def apply_liger_kernel_to_qwen3_moe(
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+def apply_liger_kernel_to_gpt_oss(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,  # Set to False by default since GPT-OSS has custom expert implementation
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace GPT-OSS models.
+    NOTE: GPT-OSS is supported in transformers >= 4.55.0
+    NOTE: SwiGLU patching is disabled by default for GPT-OSS as it uses a custom expert
+    implementation with clamping and MXFP4 quantization.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+            Note: GPT-OSS uses a custom expert implementation, so SwiGLU patching is disabled by default.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    if version.parse(transformers.__version__) < version.parse("4.55.0"):
+        logger.warning("GPT-OSS support requires transformers >= 4.55.0")
+        return
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.gpt_oss import modeling_gpt_oss
+    from transformers.models.gpt_oss.modeling_gpt_oss import GptOssModel
+
+    if rope:
+        modeling_gpt_oss.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_gpt_oss.GptOssRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(gpt_oss_lce_forward, model)
+        else:
+            modeling_gpt_oss.GptOssForCausalLM.forward = gpt_oss_lce_forward
+
+    # Note: SwiGLU patching is not implemented for GPT-OSS due to custom expert implementation
+    # with clamping (swiglu_limit=7.0) and MXFP4 quantization
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: GptOssModel = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_qwen2_vl(
     rope: bool = True,
     cross_entropy: bool = False,
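
The new GPT-OSS entry point follows the same two patching modes as the existing functions: patch the transformers classes before the model is created, or pass an already-loaded model so the instantiated RMSNorm modules are swapped as well. A hedged usage sketch (the checkpoint name is illustrative, and transformers >= 4.55.0 is required):

    # Hedged usage sketch based on apply_liger_kernel_to_gpt_oss above.
    from transformers import AutoModelForCausalLM
    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gpt_oss

    # Patch the classes first, then load the model...
    apply_liger_kernel_to_gpt_oss(rope=True, rms_norm=True, fused_linear_cross_entropy=True)
    model = AutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")  # illustrative

    # ...or patch an already-instantiated model (also replaces its RMSNorm modules in place).
    apply_liger_kernel_to_gpt_oss(model=model)
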
@@ -1518,11 +1593,10 @@ def apply_liger_kernel_to_qwen2_vl(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-
-
-
-
-        # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
+        if isinstance(model, Qwen2VLForConditionalGeneration):
+            text_model: Qwen2VLTextModel = model.model.language_model
+            vision_model: Qwen2VisionTransformerPretrainedModel = model.model.visual
+        elif isinstance(model, Qwen2VLModel):
             text_model: Qwen2VLTextModel = model.language_model
             vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
         elif isinstance(model, Qwen2VLTextModel):
@@ -1609,11 +1683,10 @@ def apply_liger_kernel_to_qwen2_5_vl(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-
-
-
-
-        # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
+        if isinstance(model, Qwen2_5_VLForConditionalGeneration):
+            text_model: Qwen2_5_VLTextModel = model.model.language_model
+            vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.model.visual
+        elif isinstance(model, Qwen2_5_VLModel):
             text_model: Qwen2_5_VLTextModel = model.language_model
             vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
         elif isinstance(model, Qwen2_5_VLTextModel):
@@ -1627,7 +1700,7 @@ def apply_liger_kernel_to_qwen2_5_vl(

     if vision_model is not None:
         # Patch Qwen2_5_VisionTransformerPretrainedModel
-        for vision_block in
+        for vision_block in vision_model.blocks:
             if rms_norm:
                 _patch_rms_norm_module(vision_block.norm1)
                 _patch_rms_norm_module(vision_block.norm2)
@@ -1643,6 +1716,162 @@ def apply_liger_kernel_to_qwen2_5_vl(
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+def apply_liger_kernel_to_qwen3_vl(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL models.
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_vl import modeling_qwen3_vl
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
+
+    from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward
+
+    if rope:
+        modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb
+        modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
+
+    if rms_norm:
+        modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen3_vl_lce_forward, model)
+        else:
+            modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward
+
+    if model is not None and rms_norm:
+        if isinstance(model, Qwen3VLForConditionalGeneration):
+            text_model: Qwen3VLTextModel = model.model.language_model
+        elif isinstance(model, Qwen3VLModel):
+            text_model: Qwen3VLTextModel = model.language_model
+        elif isinstance(model, Qwen3VLTextModel):
+            text_model = model
+        else:
+            raise TypeError(
+                f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
+            )
+
+        _patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
+
+        if text_model is not None:
+            _patch_qwen3_vl_rms_norm(text_model.norm)
+            for decoder_layer in text_model.layers:
+                _patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
+                _patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
+                self_attn = getattr(decoder_layer, "self_attn", None)
+                if self_attn is not None:
+                    if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+                        _patch_qwen3_vl_rms_norm(self_attn.q_norm)
+                    if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+                        _patch_qwen3_vl_rms_norm(self_attn.k_norm)
+
+
+def apply_liger_kernel_to_qwen3_vl_moe(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL MoE models.
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is False.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel
+
+    from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward
+
+    if rope:
+        modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+        modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
+
+    if rms_norm:
+        modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen3_vl_moe_lce_forward, model)
+        else:
+            modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward
+
+    if model is not None and rms_norm:
+        if isinstance(model, Qwen3VLMoeForConditionalGeneration):
+            text_model: Qwen3VLMoeTextModel = model.model.language_model
+        elif isinstance(model, Qwen3VLMoeModel):
+            text_model: Qwen3VLMoeTextModel = model.language_model
+        elif isinstance(model, Qwen3VLMoeTextModel):
+            text_model = model
+        else:
+            raise TypeError(
+                f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}"
+            )
+
+        _patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
+
+        if text_model is not None:
+            _patch_qwen3_vl_moe_rms_norm(text_model.norm)
+            for decoder_layer in text_model.layers:
+                _patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm)
+                _patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm)
+                self_attn = getattr(decoder_layer, "self_attn", None)
+                if self_attn is not None:
+                    if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+                        _patch_qwen3_vl_moe_rms_norm(self_attn.q_norm)
+                    if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+                        _patch_qwen3_vl_moe_rms_norm(self_attn.k_norm)
+
+
 def apply_liger_kernel_to_phi3(
     rope: bool = True,
     cross_entropy: bool = False,
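
Both Qwen3-VL entry points resolve the text model from whichever wrapper they are given (ForConditionalGeneration, Model, or TextModel) and then patch every RMSNorm through partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama"), including the per-attention q_norm/k_norm modules. A hedged sketch of the instance path (auto class and checkpoint name are illustrative):

    # Hedged sketch: instance patching for a loaded Qwen3-VL model.
    from transformers import AutoModelForImageTextToText
    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl

    model = AutoModelForImageTextToText.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")  # illustrative
    # Walks model.model.language_model and patches norm, input_layernorm,
    # post_attention_layernorm, q_norm and k_norm with the llama casting mode.
    apply_liger_kernel_to_qwen3_vl(model=model)
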
@@ -1774,6 +2003,74 @@ def apply_liger_kernel_to_olmo2(
                 _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)


+def apply_liger_kernel_to_olmo3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Olmo3 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU to Olmo3MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.olmo3 import modeling_olmo3
+    from transformers.models.olmo3.modeling_olmo3 import Olmo3Model
+
+    from liger_kernel.transformers.model.olmo3 import lce_forward as olmo3_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
+
+    # Olmo3 arch is very similar to Olmo2, so we can reuse all these components in the same way.
+    if rope:
+        modeling_olmo3.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_olmo3.Olmo3RMSNorm = LigerRMSNormForOlmo2  # same as olmo2
+    if swiglu:
+        modeling_olmo3.Olmo3MLP = LigerSwiGLUMLP
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(olmo3_lce_forward, model)
+        else:
+            modeling_olmo3.Olmo3ForCausalLM.forward = olmo3_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Olmo3Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
+
+
 def apply_liger_kernel_to_glm4(
     rope: bool = False,
     cross_entropy: bool = False,
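
Olmo3 support mirrors Olmo2: the same LigerRMSNormForOlmo2 class is swapped in, LigerSwiGLUMLP replaces Olmo3MLP, and instance patching uses in_place=False for the post-attention and post-feedforward norms. A hedged sketch of class-level patching (the checkpoint name is a placeholder, not a verified repo id):

    # Hedged sketch: patch the Olmo3 classes before loading a checkpoint.
    from transformers import AutoModelForCausalLM
    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo3

    apply_liger_kernel_to_olmo3(rope=True, rms_norm=True, swiglu=True)
    model = AutoModelForCausalLM.from_pretrained("allenai/Olmo-3-7B")  # placeholder name
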
@@ -1896,10 +2193,10 @@ def apply_liger_kernel_to_glm4v(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        if isinstance(model,
-
-
-
+        if isinstance(model, Glm4vForConditionalGeneration):
+            text_model: Glm4vTextModel = model.model.language_model
+            vision_model: Glm4vVisionModel = model.model.visual
+        elif isinstance(model, Glm4vModel):
             text_model: Glm4vTextModel = model.language_model
             vision_model: Glm4vVisionModel = model.visual
         elif isinstance(model, Glm4vTextModel):
@@ -1986,10 +2283,11 @@ def apply_liger_kernel_to_glm4v_moe(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        if isinstance(model,
-
-
-
+        if isinstance(model, Glm4vMoeForConditionalGeneration):
+            text_model: Glm4vMoeTextModel = model.model.language_model
+            vision_model: Glm4vMoeVisionModel = model.model.visual
+            Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
+        elif isinstance(model, Glm4vMoeModel):
             text_model: Glm4vMoeTextModel = model.language_model
             vision_model: Glm4vMoeVisionModel = model.visual
             Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
@@ -2038,6 +2336,7 @@ def apply_liger_kernel_to_internvl(
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
+    layer_norm: bool = True,
     model: Optional[PreTrainedModel] = None,
     **kwargs,
 ) -> None:
@@ -2048,37 +2347,62 @@ def apply_liger_kernel_to_internvl(
     NOTE: InternVL is not available in transformers<4.52.1

     Args:
-        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
         fused_linear_cross_entropy (bool):
             Whether to apply Liger's fused linear cross entropy loss. Default is True.
             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
-
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
             loaded. Default is None.
     """
     assert not (cross_entropy and fused_linear_cross_entropy), (
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )
+    import torch.nn as torch_nn

     from transformers.models.internvl import modeling_internvl
+    from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration
+    from transformers.models.internvl.modeling_internvl import InternVLModel
+    from transformers.models.internvl.modeling_internvl import InternVLVisionLayer
+    from transformers.models.internvl.modeling_internvl import InternVLVisionModel
+    from transformers.models.internvl.modeling_internvl import InternVLVisionRMSNorm

+    from liger_kernel.transformers.layer_norm import LigerLayerNorm
     from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNorm
+
+    if layer_norm and model is None:
+        modeling_internvl.nn.LayerNorm = LigerLayerNorm

     if cross_entropy:
-        logger.
-
+        logger.info("Apply liger cross entropy")
+
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
     if fused_linear_cross_entropy:
         modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
     if rms_norm:
         modeling_internvl.InternVLVisionRMSNorm = LigerRMSNorm

     if model is not None:
-
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, InternVLForConditionalGeneration):
+            text_model = model.model.language_model
+            vision_model: InternVLVisionModel = model.model.vision_tower
+        elif isinstance(model, InternVLModel):
+            text_model = model.language_model
+            vision_model: InternVLVisionModel = model.vision_tower
+        else:
+            raise TypeError(
+                f"Unsupported internvl model type. `model` must be `InternVLForConditionalGeneration`, `InternVLModel`. Got: {type(model)}"
+            )
+
+        text_model_name = model.config.text_config.model_type
         text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
-        vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None)

         kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
         if text_liger_fn:
@@ -2091,25 +2415,33 @@ def apply_liger_kernel_to_internvl(
                     f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
                     f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
                 )
-            text_kwargs["model"] =
+            text_kwargs["model"] = text_model
             text_liger_fn(**text_kwargs)
         elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
             logger.warning(f"{text_model_name} is not supported by Liger kernel.")

-
-
-
-
+        # Patch vision model RMSNorm layers
+        if rms_norm:
+            for encoder_layer in vision_model.encoder.layer:
+                encoder_layer: InternVLVisionLayer
+                if isinstance(encoder_layer.attention.q_norm, InternVLVisionRMSNorm):
+                    _patch_rms_norm_module(encoder_layer.attention.q_norm)
+                if isinstance(encoder_layer.attention.k_norm, InternVLVisionRMSNorm):
+                    _patch_rms_norm_module(encoder_layer.attention.k_norm)

-
-
-
-
-        )
-
-
-
-
+        # Patch vision model LayerNorm layers
+        if layer_norm:
+            # Patch layernorm
+            if isinstance(vision_model.layernorm, torch_nn.LayerNorm):
+                _patch_layer_norm_module(vision_model.layernorm)
+
+            # Patch encoder layers
+            for encoder_layer in vision_model.encoder.layer:
+                encoder_layer: InternVLVisionLayer
+                if isinstance(encoder_layer.layernorm_before, torch_nn.LayerNorm):
+                    _patch_layer_norm_module(encoder_layer.layernorm_before)
+                if isinstance(encoder_layer.layernorm_after, torch_nn.LayerNorm):
+                    _patch_layer_norm_module(encoder_layer.layernorm_after)


 def apply_liger_kernel_to_smolvlm(
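
The InternVL changes add a layer_norm flag and split the patching into two paths: a class-level swap of nn.LayerNorm that only runs when no model instance is passed (layer_norm and model is None), and in-place patching of the vision tower's LayerNorm and RMSNorm instances when a loaded model is provided. A hedged sketch of the two paths:

    # Hedged sketch of the layer_norm flag added above.
    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_internvl

    # Path 1: before instantiation, modeling_internvl.nn.LayerNorm is replaced globally.
    apply_liger_kernel_to_internvl(layer_norm=True)

    # Path 2: with an already-loaded InternVLForConditionalGeneration instance `model`,
    # the vision encoder's layernorm_before/layernorm_after and q_norm/k_norm modules
    # are patched in place instead:
    # apply_liger_kernel_to_internvl(layer_norm=True, model=model)
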
@@ -2372,6 +2704,200 @@ def apply_liger_kernel_to_qwen3_next(
                     _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)


+def apply_liger_kernel_to_hunyuan_v1_dense(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 dense models.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.hunyuan_v1_dense import modeling_hunyuan_v1_dense
+    from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1Model
+
+    from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_lce_forward
+    from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
+
+    if rope:
+        modeling_hunyuan_v1_dense.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_hunyuan_v1_dense.HunYuanDenseV1RMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(hunyuan_v1_lce_forward, model)
+        else:
+            modeling_hunyuan_v1_dense.HunYuanDenseV1ForCausalLM.forward = hunyuan_v1_lce_forward
+
+    if swiglu:
+        modeling_hunyuan_v1_dense.HunYuanDenseV1MLP = LigerHunyuanV1SwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: HunYuanDenseV1Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerHunyuanV1SwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_hunyuan_v1_moe(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.hunyuan_v1_moe import modeling_hunyuan_v1_moe
+    from transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe import HunYuanMoEV1Model
+
+    from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_moe_lce_forward
+    from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
+
+    if rope:
+        modeling_hunyuan_v1_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_hunyuan_v1_moe.HunYuanMoEV1RMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(hunyuan_v1_moe_lce_forward, model)
+        else:
+            modeling_hunyuan_v1_moe.HunYuanMoEV1ForCausalLM.forward = hunyuan_v1_moe_lce_forward
+
+    if swiglu:
+        modeling_hunyuan_v1_moe.HunYuanMoEV1MLP = LigerHunyuanV1SwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: HunYuanMoEV1Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                for mlp_expert in decoder_layer.mlp.experts:
+                    _patch_swiglu_module(mlp_expert, LigerHunyuanV1SwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_exaone4(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace EXAONE4 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.exaone4 import modeling_exaone4
+    from transformers.models.exaone4.modeling_exaone4 import Exaone4Model
+
+    from liger_kernel.transformers.model.exaone4 import lce_forward as exaone4_lce_forward
+
+    if rope:
+        modeling_exaone4.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        # EXAONE4 requires in_place=False to avoid gradient issues
+        class Exaone4LigerRMSNorm(LigerRMSNorm):
+            def __init__(self, hidden_size, eps=1e-6, **kwargs):
+                super().__init__(hidden_size, eps, **kwargs)
+                self.in_place = False
+
+        modeling_exaone4.Exaone4RMSNorm = Exaone4LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(exaone4_lce_forward, model)
+        else:
+            modeling_exaone4.Exaone4ForCausalLM.forward = exaone4_lce_forward
+
+    if swiglu:
+        modeling_exaone4.Exaone4MLP = LigerSwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Exaone4Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm, in_place=False)
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.self_attn.q_norm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.self_attn.k_norm, in_place=False)
+
+
 # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
 MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
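
The remaining new entry points (Hunyuan v1 dense, Hunyuan v1 MoE, EXAONE4) reuse the same template; the EXAONE4 variant is the notable one, forcing in_place=False on every patched RMSNorm via a small LigerRMSNorm subclass and via explicit in_place=False arguments during instance patching. A hedged sketch (the checkpoint name is illustrative):

    # Hedged sketch: EXAONE4 instance patching with the in_place=False convention.
    from transformers import AutoModelForCausalLM
    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_exaone4

    model = AutoModelForCausalLM.from_pretrained("LGAI-EXAONE/EXAONE-4.0-32B")  # illustrative
    apply_liger_kernel_to_exaone4(model=model)  # norms are patched with in_place=False
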
@@ -2381,6 +2907,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "glm4": apply_liger_kernel_to_glm4,
     "glm4v": apply_liger_kernel_to_glm4v,
     "glm4v_moe": apply_liger_kernel_to_glm4v_moe,
+    "gpt_oss": apply_liger_kernel_to_gpt_oss,
     "internvl": apply_liger_kernel_to_internvl,
     "llama": apply_liger_kernel_to_llama,
     "llama4_text": apply_liger_kernel_to_llama4,
@@ -2392,6 +2919,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "mistral": apply_liger_kernel_to_mistral,
     "mixtral": apply_liger_kernel_to_mixtral,
     "olmo2": apply_liger_kernel_to_olmo2,
+    "olmo3": apply_liger_kernel_to_olmo3,
     "qwen2": apply_liger_kernel_to_qwen2,
     "qwen3": apply_liger_kernel_to_qwen3,
     "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
@@ -2400,11 +2928,18 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
     "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
     "qwen3_next": apply_liger_kernel_to_qwen3_next,
+    "qwen3_vl": apply_liger_kernel_to_qwen3_vl,
+    "qwen3_vl_text": apply_liger_kernel_to_qwen3_vl,
+    "qwen3_vl_moe": apply_liger_kernel_to_qwen3_vl_moe,
+    "qwen3_vl_moe_text": apply_liger_kernel_to_qwen3_vl_moe,
     "smollm3": apply_liger_kernel_to_smollm3,
     "phi3": apply_liger_kernel_to_phi3,
     "paligemma": apply_liger_kernel_to_paligemma,
     "falcon_h1": apply_liger_kernel_to_falcon_h1,
     "smolvlm": apply_liger_kernel_to_smolvlm,
+    "hunyuan_v1_dense": apply_liger_kernel_to_hunyuan_v1_dense,
+    "hunyuan_v1_moe": apply_liger_kernel_to_hunyuan_v1_moe,
+    "exaone4": apply_liger_kernel_to_exaone4,
 }

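These registry entries are what let the multimodal wrappers above (and the AutoModel integration) dispatch on config.model_type. A hedged sketch of a manual lookup through the same mapping:

    # Hedged sketch: manual dispatch through MODEL_TYPE_TO_APPLY_LIGER_FN.
    from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN

    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get("gpt_oss")
    if apply_fn is not None:
        apply_fn(rms_norm=True, fused_linear_cross_entropy=True)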