liger-kernel-nightly 0.5.10.dev20250629005644__py3-none-any.whl → 0.5.10.dev20250630172023__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/transformers/monkey_patch.py +113 -31
- {liger_kernel_nightly-0.5.10.dev20250629005644.dist-info → liger_kernel_nightly-0.5.10.dev20250630172023.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.10.dev20250629005644.dist-info → liger_kernel_nightly-0.5.10.dev20250630172023.dist-info}/RECORD +7 -7
- {liger_kernel_nightly-0.5.10.dev20250629005644.dist-info → liger_kernel_nightly-0.5.10.dev20250630172023.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250629005644.dist-info → liger_kernel_nightly-0.5.10.dev20250630172023.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250629005644.dist-info → liger_kernel_nightly-0.5.10.dev20250630172023.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250629005644.dist-info → liger_kernel_nightly-0.5.10.dev20250630172023.dist-info}/top_level.txt +0 -0
@@ -2,6 +2,7 @@ import inspect
|
|
2
2
|
import logging
|
3
3
|
|
4
4
|
from functools import partial
|
5
|
+
from types import MethodType
|
5
6
|
from typing import Callable
|
6
7
|
|
7
8
|
import transformers
|
@@ -260,10 +261,16 @@ def apply_liger_kernel_to_llama(
|
|
260
261
|
|
261
262
|
if fused_linear_cross_entropy:
|
262
263
|
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
263
|
-
|
264
|
+
if model is not None:
|
265
|
+
model.forward = MethodType(llama_lce_forward, model)
|
266
|
+
else:
|
267
|
+
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
|
264
268
|
else: # if version < 4.46.1
|
265
269
|
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
266
|
-
|
270
|
+
if model is not None:
|
271
|
+
model.forward = MethodType(llama_lce_forward_deprecated, model)
|
272
|
+
else:
|
273
|
+
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
|
267
274
|
|
268
275
|
if model is not None:
|
269
276
|
# The model instance already exists, so we need to additionally patch the
|
@@ -318,9 +325,15 @@ def apply_liger_kernel_to_llava(
|
|
318
325
|
modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
|
319
326
|
if fused_linear_cross_entropy:
|
320
327
|
if transformer_version >= version.parse("4.52.0"):
|
321
|
-
|
328
|
+
if model is not None:
|
329
|
+
model.forward = MethodType(llava_lce_forward, model)
|
330
|
+
else:
|
331
|
+
modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
|
322
332
|
elif transformer_version >= version.parse("4.49.0") and transformer_version < version.parse("4.52.0"):
|
323
|
-
|
333
|
+
if model is not None:
|
334
|
+
model.forward = MethodType(llava_lce_forward_deprecated, model)
|
335
|
+
else:
|
336
|
+
modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
|
324
337
|
else: # if version < 4.49.0
|
325
338
|
logger.warning(
|
326
339
|
"The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
|
@@ -490,7 +503,7 @@ def apply_liger_kernel_to_mllama(
|
|
490
503
|
|
491
504
|
if rope:
|
492
505
|
modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
|
493
|
-
if layer_norm:
|
506
|
+
if layer_norm and model is None:
|
494
507
|
modeling_mllama.nn.LayerNorm = LigerLayerNorm
|
495
508
|
if rms_norm:
|
496
509
|
modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
|
@@ -506,10 +519,16 @@ def apply_liger_kernel_to_mllama(
|
|
506
519
|
modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
|
507
520
|
if fused_linear_cross_entropy:
|
508
521
|
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
509
|
-
|
522
|
+
if model is not None:
|
523
|
+
model.forward = MethodType(mllama_lce_forward, model)
|
524
|
+
else:
|
525
|
+
modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
|
510
526
|
else: # if version < 4.46.1
|
511
527
|
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
512
|
-
|
528
|
+
if model is not None:
|
529
|
+
model.forward = MethodType(mllama_lce_forward_deprecated, model)
|
530
|
+
else:
|
531
|
+
modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
|
513
532
|
|
514
533
|
if model is not None:
|
515
534
|
# The model instance already exists, so we need to additionally patch the
|
@@ -592,7 +611,10 @@ def apply_liger_kernel_to_mistral(
|
|
592
611
|
if cross_entropy:
|
593
612
|
modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
|
594
613
|
if fused_linear_cross_entropy:
|
595
|
-
|
614
|
+
if model is not None:
|
615
|
+
model.forward = MethodType(mistral_lce_forward, model)
|
616
|
+
else:
|
617
|
+
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
|
596
618
|
if swiglu:
|
597
619
|
modeling_mistral.MistralMLP = LigerSwiGLUMLP
|
598
620
|
|
@@ -660,10 +682,16 @@ def apply_liger_kernel_to_mixtral(
|
|
660
682
|
|
661
683
|
if fused_linear_cross_entropy:
|
662
684
|
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
663
|
-
|
685
|
+
if model is not None:
|
686
|
+
model.forward = MethodType(mixtral_lce_forward, model)
|
687
|
+
else:
|
688
|
+
modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
|
664
689
|
else: # if version < 4.46.1
|
665
690
|
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
666
|
-
|
691
|
+
if model is not None:
|
692
|
+
model.forward = MethodType(mixtral_lce_forward_deprecated, model)
|
693
|
+
else:
|
694
|
+
modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
|
667
695
|
if swiglu:
|
668
696
|
modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP
|
669
697
|
|
@@ -737,10 +765,16 @@ def apply_liger_kernel_to_gemma(
|
|
737
765
|
modeling_gemma.GemmaMLP = LigerGEGLUMLP
|
738
766
|
if fused_linear_cross_entropy:
|
739
767
|
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
740
|
-
|
768
|
+
if model is not None:
|
769
|
+
model.forward = MethodType(gemma_lce_forward, model)
|
770
|
+
else:
|
771
|
+
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
|
741
772
|
else: # if version < 4.46.1
|
742
773
|
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
743
|
-
|
774
|
+
if model is not None:
|
775
|
+
model.forward = MethodType(gemma_lce_forward_deprecated, model)
|
776
|
+
else:
|
777
|
+
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
|
744
778
|
|
745
779
|
if model is not None:
|
746
780
|
# The model instance already exists, so we need to additionally patch the
|
@@ -812,10 +846,16 @@ def apply_liger_kernel_to_gemma2(
|
|
812
846
|
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
|
813
847
|
if fused_linear_cross_entropy:
|
814
848
|
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
815
|
-
|
849
|
+
if model is not None:
|
850
|
+
model.forward = MethodType(gemma2_lce_forward, model)
|
851
|
+
else:
|
852
|
+
modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
|
816
853
|
else:
|
817
854
|
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
818
|
-
|
855
|
+
if model is not None:
|
856
|
+
model.forward = MethodType(gemma2_lce_forward_deprected, model)
|
857
|
+
else:
|
858
|
+
modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
|
819
859
|
if geglu:
|
820
860
|
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
|
821
861
|
|
@@ -894,7 +934,10 @@ def apply_liger_kernel_to_gemma3_text(
|
|
894
934
|
nn.functional.cross_entropy = liger_cross_entropy
|
895
935
|
|
896
936
|
if fused_linear_cross_entropy:
|
897
|
-
|
937
|
+
if model is not None:
|
938
|
+
model.forward = MethodType(causal_forward, model)
|
939
|
+
else:
|
940
|
+
modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
|
898
941
|
|
899
942
|
if model is not None:
|
900
943
|
# The model instance already exists, so we need to additionally patch the
|
@@ -964,7 +1007,7 @@ def apply_liger_kernel_to_gemma3(
|
|
964
1007
|
_patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
|
965
1008
|
)
|
966
1009
|
|
967
|
-
if layer_norm:
|
1010
|
+
if layer_norm and model is None:
|
968
1011
|
modeling_siglip.nn.LayerNorm = LigerLayerNorm
|
969
1012
|
|
970
1013
|
apply_liger_kernel_to_gemma3_text(
|
@@ -975,7 +1018,10 @@ def apply_liger_kernel_to_gemma3(
|
|
975
1018
|
modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss
|
976
1019
|
|
977
1020
|
if fused_linear_cross_entropy:
|
978
|
-
|
1021
|
+
if model is not None:
|
1022
|
+
model.forward = MethodType(multimodal_forward, model)
|
1023
|
+
else:
|
1024
|
+
modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
|
979
1025
|
|
980
1026
|
if model is not None:
|
981
1027
|
# The model instance already exists, so we need to additionally patch the
|
@@ -1054,7 +1100,7 @@ def apply_liger_kernel_to_paligemma(
|
|
1054
1100
|
from liger_kernel.transformers.model.paligemma import lce_forward_deprecated
|
1055
1101
|
|
1056
1102
|
# The vision_tower is a SiglipVisionModel
|
1057
|
-
if layer_norm:
|
1103
|
+
if layer_norm and model is None:
|
1058
1104
|
modeling_siglip.nn.LayerNorm = LigerLayerNorm
|
1059
1105
|
|
1060
1106
|
# SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
|
@@ -1072,10 +1118,16 @@ def apply_liger_kernel_to_paligemma(
|
|
1072
1118
|
modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
|
1073
1119
|
if fused_linear_cross_entropy:
|
1074
1120
|
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
1075
|
-
|
1121
|
+
if model is not None:
|
1122
|
+
model.forward = MethodType(lce_forward, model)
|
1123
|
+
else:
|
1124
|
+
modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
|
1076
1125
|
else: # if version < 4.46.1
|
1077
1126
|
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
1078
|
-
|
1127
|
+
if model is not None:
|
1128
|
+
model.forward = MethodType(lce_forward_deprecated, model)
|
1129
|
+
else:
|
1130
|
+
modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
|
1079
1131
|
|
1080
1132
|
if model is not None:
|
1081
1133
|
# The model instance already exists, so we need to additionally patch the
|
@@ -1167,10 +1219,16 @@ def apply_liger_kernel_to_qwen2(
|
|
1167
1219
|
|
1168
1220
|
if fused_linear_cross_entropy:
|
1169
1221
|
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
1170
|
-
|
1222
|
+
if model is not None:
|
1223
|
+
model.forward = MethodType(qwen2_lce_forward, model)
|
1224
|
+
else:
|
1225
|
+
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
|
1171
1226
|
else: # if version < 4.46.1
|
1172
1227
|
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
1173
|
-
|
1228
|
+
if model is not None:
|
1229
|
+
model.forward = MethodType(qwen2_lce_forward_deprecated, model)
|
1230
|
+
else:
|
1231
|
+
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
|
1174
1232
|
|
1175
1233
|
if swiglu:
|
1176
1234
|
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
|
@@ -1226,7 +1284,10 @@ def apply_liger_kernel_to_qwen3(
|
|
1226
1284
|
nn.functional.cross_entropy = liger_cross_entropy
|
1227
1285
|
|
1228
1286
|
if fused_linear_cross_entropy:
|
1229
|
-
|
1287
|
+
if model is not None:
|
1288
|
+
model.forward = MethodType(qwen3_lce_forward, model)
|
1289
|
+
else:
|
1290
|
+
modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
|
1230
1291
|
|
1231
1292
|
if swiglu:
|
1232
1293
|
modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
|
@@ -1281,7 +1342,10 @@ def apply_liger_kernel_to_qwen3_moe(
|
|
1281
1342
|
nn.functional.cross_entropy = liger_cross_entropy
|
1282
1343
|
|
1283
1344
|
if fused_linear_cross_entropy:
|
1284
|
-
|
1345
|
+
if model is not None:
|
1346
|
+
model.forward = MethodType(qwen3_lce_forward, model)
|
1347
|
+
else:
|
1348
|
+
modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
|
1285
1349
|
|
1286
1350
|
if swiglu:
|
1287
1351
|
modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
|
@@ -1350,12 +1414,15 @@ def apply_liger_kernel_to_qwen2_vl(
|
|
1350
1414
|
if rms_norm:
|
1351
1415
|
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
|
1352
1416
|
modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
|
1353
|
-
if layer_norm:
|
1417
|
+
if layer_norm and model is None:
|
1354
1418
|
modeling_qwen2_vl.LayerNorm = LigerLayerNorm
|
1355
1419
|
if cross_entropy:
|
1356
1420
|
modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
|
1357
1421
|
if fused_linear_cross_entropy:
|
1358
|
-
|
1422
|
+
if model is not None:
|
1423
|
+
model.forward = MethodType(qwen2_vl_lce_forward, model)
|
1424
|
+
else:
|
1425
|
+
modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
|
1359
1426
|
if swiglu:
|
1360
1427
|
modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
|
1361
1428
|
|
@@ -1443,7 +1510,10 @@ def apply_liger_kernel_to_qwen2_5_vl(
|
|
1443
1510
|
if cross_entropy:
|
1444
1511
|
modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
|
1445
1512
|
if fused_linear_cross_entropy:
|
1446
|
-
|
1513
|
+
if model is not None:
|
1514
|
+
model.forward = MethodType(qwen2_5_vl_lce_forward, model)
|
1515
|
+
else:
|
1516
|
+
modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
|
1447
1517
|
if swiglu:
|
1448
1518
|
modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
|
1449
1519
|
|
@@ -1530,10 +1600,16 @@ def apply_liger_kernel_to_phi3(
|
|
1530
1600
|
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
|
1531
1601
|
if fused_linear_cross_entropy:
|
1532
1602
|
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
1533
|
-
|
1603
|
+
if model is not None:
|
1604
|
+
model.forward = MethodType(phi3_lce_forward, model)
|
1605
|
+
else:
|
1606
|
+
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
|
1534
1607
|
else: # if version < 4.46.1
|
1535
1608
|
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
1536
|
-
|
1609
|
+
if model is not None:
|
1610
|
+
model.forward = MethodType(phi3_lce_forward_deprecated, model)
|
1611
|
+
else:
|
1612
|
+
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
|
1537
1613
|
|
1538
1614
|
if model is not None:
|
1539
1615
|
# The model instance already exists, so we need to additionally patch the
|
@@ -1597,7 +1673,10 @@ def apply_liger_kernel_to_olmo2(
|
|
1597
1673
|
|
1598
1674
|
nn.functional.cross_entropy = liger_cross_entropy
|
1599
1675
|
if fused_linear_cross_entropy:
|
1600
|
-
|
1676
|
+
if model is not None:
|
1677
|
+
model.forward = MethodType(olmo2_lce_forward, model)
|
1678
|
+
else:
|
1679
|
+
modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
|
1601
1680
|
|
1602
1681
|
if model is not None:
|
1603
1682
|
# The model instance already exists, so we need to additionally patch the
|
@@ -1661,7 +1740,10 @@ def apply_liger_kernel_to_glm4(
|
|
1661
1740
|
|
1662
1741
|
nn.functional.cross_entropy = liger_cross_entropy
|
1663
1742
|
if fused_linear_cross_entropy:
|
1664
|
-
|
1743
|
+
if model is not None:
|
1744
|
+
model.forward = MethodType(glm4_lce_forward, model)
|
1745
|
+
else:
|
1746
|
+
modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
|
1665
1747
|
|
1666
1748
|
if model is not None:
|
1667
1749
|
# The model instance already exists, so we need to additionally patch the
|
@@ -53,7 +53,7 @@ liger_kernel/transformers/grpo_loss.py,sha256=uAkUNKSnUGEOqa82L9w2e6AI1kcmG8K45-
|
|
53
53
|
liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
|
54
54
|
liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
|
55
55
|
liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
|
56
|
-
liger_kernel/transformers/monkey_patch.py,sha256=
|
56
|
+
liger_kernel/transformers/monkey_patch.py,sha256=YkX0LT6lISg3UTqFjjt9kTr36WgiHvYTQObAS1_Bmi4,85172
|
57
57
|
liger_kernel/transformers/multi_token_attention.py,sha256=l9VDICK0dfmifUDW668hGscP8AHq2rYcM2oGUa3baRQ,1751
|
58
58
|
liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
|
59
59
|
liger_kernel/transformers/rms_norm.py,sha256=vkekcvTeWY8vL4H6hg3t0XeY0Ew_3OFMPHuzqlxPPVw,2719
|
@@ -88,9 +88,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
88
88
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
|
89
89
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
90
90
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
91
|
-
liger_kernel_nightly-0.5.10.
|
92
|
-
liger_kernel_nightly-0.5.10.
|
93
|
-
liger_kernel_nightly-0.5.10.
|
94
|
-
liger_kernel_nightly-0.5.10.
|
95
|
-
liger_kernel_nightly-0.5.10.
|
96
|
-
liger_kernel_nightly-0.5.10.
|
91
|
+
liger_kernel_nightly-0.5.10.dev20250630172023.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
92
|
+
liger_kernel_nightly-0.5.10.dev20250630172023.dist-info/METADATA,sha256=R9S054XUfsyrq9HECn8SHjNLRdXF6KxS6vP1w_fuqjI,24536
|
93
|
+
liger_kernel_nightly-0.5.10.dev20250630172023.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
94
|
+
liger_kernel_nightly-0.5.10.dev20250630172023.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
95
|
+
liger_kernel_nightly-0.5.10.dev20250630172023.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
96
|
+
liger_kernel_nightly-0.5.10.dev20250630172023.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|