liger-kernel 0.6.4__py3-none-any.whl → 0.6.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/cosine_similarity_loss.py +7 -1
- liger_kernel/chunked_loss/fused_linear_distillation.py +10 -3
- liger_kernel/chunked_loss/jsd_loss.py +21 -6
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +492 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +61 -0
- liger_kernel/ops/backends/_ascend/ops/embedding.py +214 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +191 -0
- liger_kernel/ops/backends/_ascend/ops/llama4_rope.py +298 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +275 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +265 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ops/tvd.py +223 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +367 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +14 -4
- liger_kernel/ops/dyt.py +5 -2
- liger_kernel/ops/fused_add_rms_norm.py +21 -23
- liger_kernel/ops/fused_linear_cross_entropy.py +2 -1
- liger_kernel/ops/geglu.py +5 -3
- liger_kernel/ops/group_norm.py +12 -8
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +17 -16
- liger_kernel/ops/poly_norm.py +19 -21
- liger_kernel/ops/rms_norm.py +149 -71
- liger_kernel/ops/utils.py +25 -0
- liger_kernel/transformers/__init__.py +6 -0
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +1 -1
- liger_kernel/transformers/dyt.py +1 -1
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/functional.py +20 -20
- liger_kernel/transformers/fused_add_rms_norm.py +1 -1
- liger_kernel/transformers/fused_linear_cross_entropy.py +1 -1
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +1 -1
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +1 -1
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/gemma2.py +3 -3
- liger_kernel/transformers/model/gemma3.py +11 -5
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/loss_utils.py +6 -0
- liger_kernel/transformers/model/paligemma.py +1 -0
- liger_kernel/transformers/monkey_patch.py +196 -39
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +1 -1
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +8 -3
- liger_kernel/transformers/rope.py +28 -27
- liger_kernel/transformers/softmax.py +1 -1
- liger_kernel/transformers/sparsemax.py +1 -1
- liger_kernel/transformers/swiglu.py +1 -1
- liger_kernel/transformers/tiled_mlp.py +5 -13
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +54 -0
- {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/METADATA +11 -4
- liger_kernel-0.6.5.dist-info/RECORD +134 -0
- {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/WHEEL +1 -1
- liger_kernel-0.6.4.dist-info/RECORD +0 -118
- {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/monkey_patch.py

```diff
@@ -20,6 +20,7 @@ from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
 from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
 from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
 from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
+from liger_kernel.transformers.model.gpt_oss import lce_forward as gpt_oss_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
 from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
@@ -34,8 +35,7 @@ from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
 from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
-from liger_kernel.transformers.rope import
-from liger_kernel.transformers.rope import liger_rotary_pos_emb_with_cast_and_leading_batch
+from liger_kernel.transformers.rope import liger_rotary_pos_emb_vision
 from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
 from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
 from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
```
```diff
@@ -430,7 +430,7 @@ def apply_liger_kernel_to_llava(
                     f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
                     f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
                 )
-            text_kwargs["model"] = model.language_model
+            text_kwargs["model"] = model.model.language_model
             text_liger_fn(**text_kwargs)
         elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
             logger.warning(f"{text_model_name} is not supported by Liger kernel.")
@@ -445,7 +445,7 @@ def apply_liger_kernel_to_llava(
                     f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
                     f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
                 )
-            vision_kwargs["model"] = model.vision_tower
+            vision_kwargs["model"] = model.model.vision_tower
             vision_liger_fn(**vision_kwargs)
         elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
             logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
```
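The repeated `model.language_model` → `model.model.language_model` (and `vision_tower`/`visual`) edits in this file track the Hugging Face Transformers refactor in which the `*ForConditionalGeneration` wrappers nest their text and vision submodules under an inner `.model` module. A minimal sketch of the attribute layout these patches now assume (the checkpoint name is illustrative only, and older transformers releases still expose the flat attributes that 0.6.4 patched):

```python
from transformers import LlavaForConditionalGeneration

# Illustrative checkpoint; any Llava-style model shows the same layout.
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Newer transformers releases nest submodules under the inner `.model` wrapper,
# which is what liger-kernel 0.6.5 now patches:
language_model = model.model.language_model  # text backbone
vision_tower = model.model.vision_tower      # vision encoder

# liger-kernel 0.6.4 still reached for the flat attributes:
# language_model = model.language_model
# vision_tower = model.vision_tower
```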
```diff
@@ -615,8 +615,8 @@ def apply_liger_kernel_to_mllama(
         # instance variables that reference already-instantiated modules

         if isinstance(model, MllamaForConditionalGeneration):
-            language_model: MllamaForCausalLM = model.language_model
-            vision_model: MllamaVisionModel = model.vision_model
+            language_model: MllamaForCausalLM = model.model.language_model
+            vision_model: MllamaVisionModel = model.model.vision_model
             if isinstance(language_model, MllamaForCausalLM):
                 text_model: MllamaTextModel = language_model.model
             else:
```
|
|
|
1118
1118
|
# instance variables that reference already-instantiated modules
|
|
1119
1119
|
|
|
1120
1120
|
if isinstance(model, Gemma3ForConditionalGeneration):
|
|
1121
|
-
if isinstance(model.vision_tower, SiglipVisionModel):
|
|
1122
|
-
vision_tower = model.vision_tower
|
|
1121
|
+
if isinstance(model.model.vision_tower, SiglipVisionModel):
|
|
1122
|
+
vision_tower = model.model.vision_tower
|
|
1123
1123
|
|
|
1124
1124
|
_patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
|
|
1125
1125
|
|
|
@@ -1132,7 +1132,7 @@ def apply_liger_kernel_to_gemma3(
|
|
|
1132
1132
|
raise TypeError("The vision tower must be SiglipVisionModel")
|
|
1133
1133
|
|
|
1134
1134
|
if rms_norm:
|
|
1135
|
-
_patch_rms_norm_module_for_gemma3(model.multi_modal_projector.mm_soft_emb_norm)
|
|
1135
|
+
_patch_rms_norm_module_for_gemma3(model.model.multi_modal_projector.mm_soft_emb_norm)
|
|
1136
1136
|
|
|
1137
1137
|
apply_liger_kernel_to_gemma3_text(
|
|
1138
1138
|
rope=rope,
|
|
@@ -1140,7 +1140,7 @@ def apply_liger_kernel_to_gemma3(
|
|
|
1140
1140
|
fused_linear_cross_entropy=False,
|
|
1141
1141
|
rms_norm=rms_norm,
|
|
1142
1142
|
geglu=geglu,
|
|
1143
|
-
model=model.language_model,
|
|
1143
|
+
model=model.model.language_model,
|
|
1144
1144
|
)
|
|
1145
1145
|
|
|
1146
1146
|
else:
|
|
```diff
@@ -1228,7 +1228,7 @@ def apply_liger_kernel_to_paligemma(
         if not isinstance(model, PaliGemmaForConditionalGeneration):
             raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")

-        vision_tower: SiglipVisionModel = model.vision_tower
+        vision_tower: SiglipVisionModel = model.model.vision_tower

         _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)

@@ -1238,7 +1238,7 @@ def apply_liger_kernel_to_paligemma(
             _patch_layer_norm_module(layer.layer_norm1)
             _patch_layer_norm_module(layer.layer_norm2)

-        language_model = model.language_model
+        language_model = model.model.language_model

         if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
             apply_liger_kernel_to_gemma(
```
```diff
@@ -1459,6 +1459,79 @@ def apply_liger_kernel_to_qwen3_moe(
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+def apply_liger_kernel_to_gpt_oss(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,  # Set to False by default since GPT-OSS has custom expert implementation
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace GPT-OSS models.
+    NOTE: GPT-OSS is supported in transformers >= 4.55.0
+    NOTE: SwiGLU patching is disabled by default for GPT-OSS as it uses a custom expert
+    implementation with clamping and MXFP4 quantization.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+            Note: GPT-OSS uses a custom expert implementation, so SwiGLU patching is disabled by default.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    if version.parse(transformers.__version__) < version.parse("4.55.0"):
+        logger.warning("GPT-OSS support requires transformers >= 4.55.0")
+        return
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.gpt_oss import modeling_gpt_oss
+    from transformers.models.gpt_oss.modeling_gpt_oss import GptOssModel
+
+    if rope:
+        modeling_gpt_oss.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_gpt_oss.GptOssRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(gpt_oss_lce_forward, model)
+        else:
+            modeling_gpt_oss.GptOssForCausalLM.forward = gpt_oss_lce_forward
+
+    # Note: SwiGLU patching is not implemented for GPT-OSS due to custom expert implementation
+    # with clamping (swiglu_limit=7.0) and MXFP4 quantization
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: GptOssModel = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_qwen2_vl(
     rope: bool = True,
     cross_entropy: bool = False,
```
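For reference, a minimal usage sketch of the new GPT-OSS path. The call signature comes from the function above; the checkpoint name, the `liger_kernel.transformers` re-export, and the surrounding setup are illustrative assumptions:

```python
import torch
from transformers import AutoModelForCausalLM

# Assumed re-export; otherwise import from liger_kernel.transformers.monkey_patch.
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss

# Patch the GPT-OSS module classes before the model is instantiated so that
# RoPE, RMSNorm and the fused linear cross entropy forward are swapped in.
apply_liger_kernel_to_gpt_oss(rope=True, rms_norm=True, fused_linear_cross_entropy=True)

# Illustrative checkpoint; GPT-OSS support requires transformers >= 4.55.0.
model = AutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b", torch_dtype=torch.bfloat16)

# Alternatively, patch an already-instantiated model in place:
# apply_liger_kernel_to_gpt_oss(model=model)
```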
```diff
@@ -1520,11 +1593,10 @@ def apply_liger_kernel_to_qwen2_vl(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-
-
-
-
-        # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
+        if isinstance(model, Qwen2VLForConditionalGeneration):
+            text_model: Qwen2VLTextModel = model.model.language_model
+            vision_model: Qwen2VisionTransformerPretrainedModel = model.model.visual
+        elif isinstance(model, Qwen2VLModel):
             text_model: Qwen2VLTextModel = model.language_model
             vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
         elif isinstance(model, Qwen2VLTextModel):
```
```diff
@@ -1611,11 +1683,10 @@ def apply_liger_kernel_to_qwen2_5_vl(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-
-
-
-
-        # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
+        if isinstance(model, Qwen2_5_VLForConditionalGeneration):
+            text_model: Qwen2_5_VLTextModel = model.model.language_model
+            vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.model.visual
+        elif isinstance(model, Qwen2_5_VLModel):
             text_model: Qwen2_5_VLTextModel = model.language_model
             vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
         elif isinstance(model, Qwen2_5_VLTextModel):
@@ -1629,7 +1700,7 @@ def apply_liger_kernel_to_qwen2_5_vl(

         if vision_model is not None:
             # Patch Qwen2_5_VisionTransformerPretrainedModel
-            for vision_block in
+            for vision_block in vision_model.blocks:
                 if rms_norm:
                     _patch_rms_norm_module(vision_block.norm1)
                     _patch_rms_norm_module(vision_block.norm2)
```
```diff
@@ -1680,8 +1751,8 @@ def apply_liger_kernel_to_qwen3_vl(
     from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward

     if rope:
-        modeling_qwen3_vl.apply_rotary_pos_emb =
-        modeling_qwen3_vl.apply_rotary_pos_emb_vision =
+        modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb
+        modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision

     if rms_norm:
         modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm
@@ -1698,7 +1769,9 @@ def apply_liger_kernel_to_qwen3_vl(
             modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward

     if model is not None and rms_norm:
-        if isinstance(model,
+        if isinstance(model, Qwen3VLForConditionalGeneration):
+            text_model: Qwen3VLTextModel = model.model.language_model
+        elif isinstance(model, Qwen3VLModel):
             text_model: Qwen3VLTextModel = model.language_model
         elif isinstance(model, Qwen3VLTextModel):
             text_model = model
```
```diff
@@ -1755,8 +1828,8 @@ def apply_liger_kernel_to_qwen3_vl_moe(
     from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward

     if rope:
-        modeling_qwen3_vl_moe.apply_rotary_pos_emb =
-        modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision =
+        modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+        modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision

     if rms_norm:
         modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm
@@ -1773,7 +1846,9 @@ def apply_liger_kernel_to_qwen3_vl_moe(
             modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward

     if model is not None and rms_norm:
-        if isinstance(model,
+        if isinstance(model, Qwen3VLMoeForConditionalGeneration):
+            text_model: Qwen3VLMoeTextModel = model.model.language_model
+        elif isinstance(model, Qwen3VLMoeModel):
             text_model: Qwen3VLMoeTextModel = model.language_model
         elif isinstance(model, Qwen3VLMoeTextModel):
             text_model = model
```
```diff
@@ -2118,10 +2193,10 @@ def apply_liger_kernel_to_glm4v(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        if isinstance(model,
-
-
-
+        if isinstance(model, Glm4vForConditionalGeneration):
+            text_model: Glm4vTextModel = model.model.language_model
+            vision_model: Glm4vVisionModel = model.model.visual
+        elif isinstance(model, Glm4vModel):
             text_model: Glm4vTextModel = model.language_model
             vision_model: Glm4vVisionModel = model.visual
         elif isinstance(model, Glm4vTextModel):
```
```diff
@@ -2208,10 +2283,11 @@ def apply_liger_kernel_to_glm4v_moe(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        if isinstance(model,
-
-
-
+        if isinstance(model, Glm4vMoeForConditionalGeneration):
+            text_model: Glm4vMoeTextModel = model.model.language_model
+            vision_model: Glm4vMoeVisionModel = model.model.visual
+            Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
+        elif isinstance(model, Glm4vMoeModel):
             text_model: Glm4vMoeTextModel = model.language_model
             vision_model: Glm4vMoeVisionModel = model.visual
             Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
```
```diff
@@ -2314,8 +2390,10 @@ def apply_liger_kernel_to_internvl(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        if isinstance(model,
-
+        if isinstance(model, InternVLForConditionalGeneration):
+            text_model = model.model.language_model
+            vision_model: InternVLVisionModel = model.model.vision_tower
+        elif isinstance(model, InternVLModel):
             text_model = model.language_model
             vision_model: InternVLVisionModel = model.vision_tower
         else:
```
```diff
@@ -2743,6 +2821,83 @@ def apply_liger_kernel_to_hunyuan_v1_moe(
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+def apply_liger_kernel_to_exaone4(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace EXAONE4 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.exaone4 import modeling_exaone4
+    from transformers.models.exaone4.modeling_exaone4 import Exaone4Model
+
+    from liger_kernel.transformers.model.exaone4 import lce_forward as exaone4_lce_forward
+
+    if rope:
+        modeling_exaone4.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        # EXAONE4 requires in_place=False to avoid gradient issues
+        class Exaone4LigerRMSNorm(LigerRMSNorm):
+            def __init__(self, hidden_size, eps=1e-6, **kwargs):
+                super().__init__(hidden_size, eps, **kwargs)
+                self.in_place = False
+
+        modeling_exaone4.Exaone4RMSNorm = Exaone4LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(exaone4_lce_forward, model)
+        else:
+            modeling_exaone4.Exaone4ForCausalLM.forward = exaone4_lce_forward
+
+    if swiglu:
+        modeling_exaone4.Exaone4MLP = LigerSwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Exaone4Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm, in_place=False)
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.self_attn.q_norm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.self_attn.k_norm, in_place=False)
+
+
 # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
 MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
```
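A matching sketch for the EXAONE4 path, exercising the `model is not None` branch above (the checkpoint name and the `liger_kernel.transformers` re-export are assumptions):

```python
from transformers import AutoModelForCausalLM

# Assumed re-export; otherwise import from liger_kernel.transformers.monkey_patch.
from liger_kernel.transformers import apply_liger_kernel_to_exaone4

# Illustrative checkpoint name.
model = AutoModelForCausalLM.from_pretrained("LGAI-EXAONE/EXAONE-4.0-32B")

# Patching after instantiation takes the `model is not None` branch: the existing
# RMSNorm modules are patched with in_place=False and the decoder MLP forwards
# are rebound to LigerSwiGLUMLP.forward.
apply_liger_kernel_to_exaone4(model=model)
```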
```diff
@@ -2752,6 +2907,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "glm4": apply_liger_kernel_to_glm4,
     "glm4v": apply_liger_kernel_to_glm4v,
     "glm4v_moe": apply_liger_kernel_to_glm4v_moe,
+    "gpt_oss": apply_liger_kernel_to_gpt_oss,
     "internvl": apply_liger_kernel_to_internvl,
     "llama": apply_liger_kernel_to_llama,
     "llama4_text": apply_liger_kernel_to_llama4,
@@ -2783,6 +2939,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "smolvlm": apply_liger_kernel_to_smolvlm,
     "hunyuan_v1_dense": apply_liger_kernel_to_hunyuan_v1_dense,
     "hunyuan_v1_moe": apply_liger_kernel_to_hunyuan_v1_moe,
+    "exaone4": apply_liger_kernel_to_exaone4,
 }

```
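Registering `"gpt_oss"` and `"exaone4"` in `MODEL_TYPE_TO_APPLY_LIGER_FN` is what makes the new model types reachable through the automatic patching path as well. A hedged sketch of that path (the auto wrapper is assumed from `liger_kernel.transformers.auto_model`, and the checkpoint name is illustrative):

```python
from liger_kernel.transformers import AutoLigerKernelForCausalLM

# The auto wrapper reads config.model_type (e.g. "gpt_oss" or "exaone4"), looks it
# up in MODEL_TYPE_TO_APPLY_LIGER_FN, and applies the matching patch function
# before loading the weights. Checkpoint name is illustrative only.
model = AutoLigerKernelForCausalLM.from_pretrained("openai/gpt-oss-20b")
```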
liger_kernel/transformers/multi_token_attention.py

```diff
@@ -5,7 +5,7 @@ import torch.nn as nn

 from torch.nn.modules.utils import _pair

-from liger_kernel.ops
+from liger_kernel.ops import LigerMultiTokenAttentionFunction


 class LigerMultiTokenAttention(nn.Module):
```
liger_kernel/transformers/rms_norm.py

```diff
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn

-from liger_kernel.ops
+from liger_kernel.ops import LigerRMSNormFunction


 class LigerRMSNorm(nn.Module):
@@ -14,13 +14,18 @@ class LigerRMSNorm(nn.Module):
         init_fn="ones",
         in_place=True,
         row_mode=None,
+        elementwise_affine=True,
     ):
         super().__init__()
         assert init_fn in [
             "ones",
             "zeros",
         ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
-        self.
+        self.elementwise_affine = elementwise_affine
+        if self.elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
+        else:
+            self.register_parameter("weight", None)
         self.variance_epsilon, self.offset, self.casting_mode, self.in_place, self.row_mode = (
             eps,
             offset,
@@ -41,7 +46,7 @@ class LigerRMSNorm(nn.Module):
         )

     def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}, row_mode={self.row_mode}"
+        return f"weight_shape={tuple(self.weight.shape) if self.weight is not None else None}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}, row_mode={self.row_mode}"


 class LigerRMSNormForGemma(LigerRMSNorm):
```
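The new `elementwise_affine` flag lets `LigerRMSNorm` be constructed without a learnable weight, mirroring `torch.nn.RMSNorm`. A minimal sketch of both construction paths (sizes are illustrative; running the kernel itself needs a Triton-capable GPU):

```python
import torch

from liger_kernel.transformers.rms_norm import LigerRMSNorm

# Default: learnable per-channel weight, exactly as in 0.6.4.
norm = LigerRMSNorm(hidden_size=4096, eps=1e-6).to("cuda")
out = norm(torch.randn(2, 16, 4096, device="cuda", dtype=torch.bfloat16))

# New in 0.6.5: no affine weight; `weight` is registered as None.
norm_no_affine = LigerRMSNorm(hidden_size=4096, eps=1e-6, elementwise_affine=False)
print(norm_no_affine.weight)        # None
print(norm_no_affine.extra_repr())  # weight_shape=None, eps=1e-06, ...
```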
liger_kernel/transformers/rope.py

```diff
@@ -1,9 +1,8 @@
-from typing import Optional
 from typing import Tuple

 import torch

-from liger_kernel.ops
+from liger_kernel.ops import LigerRopeFunction


 def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
@@ -25,39 +24,41 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)


-def
+def liger_rotary_pos_emb_vision(
     q: torch.Tensor,
     k: torch.Tensor,
     cos: torch.Tensor,
     sin: torch.Tensor,
-    position_ids: Optional[torch.Tensor] = None,
-    unsqueeze_dim: int = 1,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Modified version of liger_rotary_pos_emb for qwen3_vl's apply_rotary_pos_emb_vision function.
+    Manually tranposed the input and output to match the expected shape for liger_rotary_pos_emb.
+    Reference: https://https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L116
+
+    Args:
+        q (torch.Tensor): The query tensor of shape (seq_length, num_heads, head_dim),
+            with stride (num_heads * head_dim, head_dim, 1).
+        k (torch.Tensor): The query tensor of shape (seq_length, num_heads, head_dim),
+            with stride (num_heads * head_dim, head_dim, 1). Same as q.
+        cos (torch.Tensor): The cosine tensor of shape (seq_length, head_dim).
+        sin (torch.Tensor): The sine tensor of shape (seq_length, head_dim).
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The query and key tensors with the same shape and stride as inputs.
+    """
     orig_q_dtype, orig_k_dtype = q.dtype, k.dtype

-
-
+    # tranpose to (1, num_heads, seq_length, head_dim) and cast to float32 to match liger_rotary_pos_emb input shape
+    # also unsqueeze for batch dim
+    q32 = q.to(torch.float32).unsqueeze(0).transpose(1, 2)
+    k32 = k.to(torch.float32).unsqueeze(0).transpose(1, 2)
     cos32 = cos.to(torch.float32)
     sin32 = sin.to(torch.float32)

-    q_out, k_out = liger_rotary_pos_emb(q32, k32, cos32, sin32
-    return q_out.to(orig_q_dtype), k_out.to(orig_k_dtype)
-
-
-def liger_rotary_pos_emb_with_cast_and_leading_batch(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    cos: torch.Tensor,
-    sin: torch.Tensor,
-    position_ids: Optional[torch.Tensor] = None,
-    unsqueeze_dim: int = 1,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
-
-    q32 = q.to(torch.float32).unsqueeze(0)
-    k32 = k.to(torch.float32).unsqueeze(0)
-    cos32 = cos.to(torch.float32).unsqueeze(0)
-    sin32 = sin.to(torch.float32).unsqueeze(0)
+    q_out, k_out = liger_rotary_pos_emb(q32, k32, cos32, sin32)

-
-
+    # transpose back to (seq_length, num_heads, head_dim) and cast back to original dtype
+    # also squeeze out batch dim
+    q_out = q_out.transpose(1, 2).squeeze(0).to(orig_q_dtype)
+    k_out = k_out.transpose(1, 2).squeeze(0).to(orig_k_dtype)
+    return q_out, k_out
```
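A small shape check for the new vision RoPE helper, following the docstring above: inputs are flattened `(seq_length, num_heads, head_dim)` tensors rather than the batched `(batch, num_heads, seq_length, head_dim)` layout that `liger_rotary_pos_emb` expects. Sizes are illustrative and a Triton-capable GPU is required:

```python
import torch

from liger_kernel.transformers.rope import liger_rotary_pos_emb_vision

seq_len, num_heads, head_dim = 1024, 16, 80  # illustrative vision-tower sizes

q = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
cos = torch.randn(seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(seq_len, head_dim, device="cuda", dtype=torch.bfloat16)

q_rot, k_rot = liger_rotary_pos_emb_vision(q, k, cos, sin)
assert q_rot.shape == q.shape and q_rot.dtype == q.dtype
```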
liger_kernel/transformers/tiled_mlp.py

```diff
@@ -2,9 +2,9 @@ from typing import Optional

 import torch.nn as nn

-from liger_kernel.ops
-from liger_kernel.ops
-from liger_kernel.ops
+from liger_kernel.ops import LigerGELUMulFunction
+from liger_kernel.ops import LigerSiLUMulFunction
+from liger_kernel.ops import apply_tiled_mlp


 class LigerTiledGEGLUMLP(nn.Module):
@@ -57,11 +57,7 @@
         Returns:
             Output tensor of the same shape as input
         """
-        compute_params = [
-            self.gate_proj.weight,
-            self.up_proj.weight,
-            self.down_proj.weight,
-        ]
+        compute_params = [p for p in self.parameters() if p.requires_grad]

         return apply_tiled_mlp(
             fn=self._mlp_forward,
@@ -118,11 +114,7 @@
         Returns:
             Output tensor of the same shape as input
         """
-        compute_params = [
-            self.gate_proj.weight,
-            self.up_proj.weight,
-            self.down_proj.weight,
-        ]
+        compute_params = [p for p in self.parameters() if p.requires_grad]

         return apply_tiled_mlp(
             fn=self._mlp_forward,
```