liger-kernel-nightly 0.5.10.dev20250629005644__py3-none-any.whl → 0.5.10.dev20250630172023__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,6 +2,7 @@ import inspect
2
2
  import logging
3
3
 
4
4
  from functools import partial
5
+ from types import MethodType
5
6
  from typing import Callable
6
7
 
7
8
  import transformers
@@ -260,10 +261,16 @@ def apply_liger_kernel_to_llama(
260
261
 
261
262
  if fused_linear_cross_entropy:
262
263
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
263
- modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
264
+ if model is not None:
265
+ model.forward = MethodType(llama_lce_forward, model)
266
+ else:
267
+ modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
264
268
  else: # if version < 4.46.1
265
269
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
266
- modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
270
+ if model is not None:
271
+ model.forward = MethodType(llama_lce_forward_deprecated, model)
272
+ else:
273
+ modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
267
274
 
268
275
  if model is not None:
269
276
  # The model instance already exists, so we need to additionally patch the
@@ -318,9 +325,15 @@ def apply_liger_kernel_to_llava(
318
325
  modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
319
326
  if fused_linear_cross_entropy:
320
327
  if transformer_version >= version.parse("4.52.0"):
321
- modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
328
+ if model is not None:
329
+ model.forward = MethodType(llava_lce_forward, model)
330
+ else:
331
+ modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
322
332
  elif transformer_version >= version.parse("4.49.0") and transformer_version < version.parse("4.52.0"):
323
- modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
333
+ if model is not None:
334
+ model.forward = MethodType(llava_lce_forward_deprecated, model)
335
+ else:
336
+ modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
324
337
  else: # if version < 4.49.0
325
338
  logger.warning(
326
339
  "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
@@ -490,7 +503,7 @@ def apply_liger_kernel_to_mllama(
490
503
 
491
504
  if rope:
492
505
  modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
493
- if layer_norm:
506
+ if layer_norm and model is None:
494
507
  modeling_mllama.nn.LayerNorm = LigerLayerNorm
495
508
  if rms_norm:
496
509
  modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
@@ -506,10 +519,16 @@ def apply_liger_kernel_to_mllama(
506
519
  modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
507
520
  if fused_linear_cross_entropy:
508
521
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
509
- modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
522
+ if model is not None:
523
+ model.forward = MethodType(mllama_lce_forward, model)
524
+ else:
525
+ modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
510
526
  else: # if version < 4.46.1
511
527
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
512
- modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
528
+ if model is not None:
529
+ model.forward = MethodType(mllama_lce_forward_deprecated, model)
530
+ else:
531
+ modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
513
532
 
514
533
  if model is not None:
515
534
  # The model instance already exists, so we need to additionally patch the
@@ -592,7 +611,10 @@ def apply_liger_kernel_to_mistral(
592
611
  if cross_entropy:
593
612
  modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
594
613
  if fused_linear_cross_entropy:
595
- modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
614
+ if model is not None:
615
+ model.forward = MethodType(mistral_lce_forward, model)
616
+ else:
617
+ modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
596
618
  if swiglu:
597
619
  modeling_mistral.MistralMLP = LigerSwiGLUMLP
598
620
 
@@ -660,10 +682,16 @@ def apply_liger_kernel_to_mixtral(
660
682
 
661
683
  if fused_linear_cross_entropy:
662
684
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
663
- modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
685
+ if model is not None:
686
+ model.forward = MethodType(mixtral_lce_forward, model)
687
+ else:
688
+ modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
664
689
  else: # if version < 4.46.1
665
690
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
666
- modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
691
+ if model is not None:
692
+ model.forward = MethodType(mixtral_lce_forward_deprecated, model)
693
+ else:
694
+ modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
667
695
  if swiglu:
668
696
  modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP
669
697
 
@@ -737,10 +765,16 @@ def apply_liger_kernel_to_gemma(
737
765
  modeling_gemma.GemmaMLP = LigerGEGLUMLP
738
766
  if fused_linear_cross_entropy:
739
767
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
740
- modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
768
+ if model is not None:
769
+ model.forward = MethodType(gemma_lce_forward, model)
770
+ else:
771
+ modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
741
772
  else: # if version < 4.46.1
742
773
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
743
- modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
774
+ if model is not None:
775
+ model.forward = MethodType(gemma_lce_forward_deprecated, model)
776
+ else:
777
+ modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
744
778
 
745
779
  if model is not None:
746
780
  # The model instance already exists, so we need to additionally patch the
@@ -812,10 +846,16 @@ def apply_liger_kernel_to_gemma2(
812
846
  modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
813
847
  if fused_linear_cross_entropy:
814
848
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
815
- modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
849
+ if model is not None:
850
+ model.forward = MethodType(gemma2_lce_forward, model)
851
+ else:
852
+ modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
816
853
  else:
817
854
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
818
- modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
855
+ if model is not None:
856
+ model.forward = MethodType(gemma2_lce_forward_deprected, model)
857
+ else:
858
+ modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
819
859
  if geglu:
820
860
  modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
821
861
 
@@ -894,7 +934,10 @@ def apply_liger_kernel_to_gemma3_text(
894
934
  nn.functional.cross_entropy = liger_cross_entropy
895
935
 
896
936
  if fused_linear_cross_entropy:
897
- modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
937
+ if model is not None:
938
+ model.forward = MethodType(causal_forward, model)
939
+ else:
940
+ modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
898
941
 
899
942
  if model is not None:
900
943
  # The model instance already exists, so we need to additionally patch the
@@ -964,7 +1007,7 @@ def apply_liger_kernel_to_gemma3(
964
1007
  _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
965
1008
  )
966
1009
 
967
- if layer_norm:
1010
+ if layer_norm and model is None:
968
1011
  modeling_siglip.nn.LayerNorm = LigerLayerNorm
969
1012
 
970
1013
  apply_liger_kernel_to_gemma3_text(
@@ -975,7 +1018,10 @@ def apply_liger_kernel_to_gemma3(
975
1018
  modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss
976
1019
 
977
1020
  if fused_linear_cross_entropy:
978
- modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
1021
+ if model is not None:
1022
+ model.forward = MethodType(multimodal_forward, model)
1023
+ else:
1024
+ modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
979
1025
 
980
1026
  if model is not None:
981
1027
  # The model instance already exists, so we need to additionally patch the
@@ -1054,7 +1100,7 @@ def apply_liger_kernel_to_paligemma(
1054
1100
  from liger_kernel.transformers.model.paligemma import lce_forward_deprecated
1055
1101
 
1056
1102
  # The vision_tower is a SiglipVisionModel
1057
- if layer_norm:
1103
+ if layer_norm and model is None:
1058
1104
  modeling_siglip.nn.LayerNorm = LigerLayerNorm
1059
1105
 
1060
1106
  # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
@@ -1072,10 +1118,16 @@ def apply_liger_kernel_to_paligemma(
1072
1118
  modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
1073
1119
  if fused_linear_cross_entropy:
1074
1120
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
1075
- modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
1121
+ if model is not None:
1122
+ model.forward = MethodType(lce_forward, model)
1123
+ else:
1124
+ modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
1076
1125
  else: # if version < 4.46.1
1077
1126
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
1078
- modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
1127
+ if model is not None:
1128
+ model.forward = MethodType(lce_forward_deprecated, model)
1129
+ else:
1130
+ modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
1079
1131
 
1080
1132
  if model is not None:
1081
1133
  # The model instance already exists, so we need to additionally patch the
@@ -1167,10 +1219,16 @@ def apply_liger_kernel_to_qwen2(
1167
1219
 
1168
1220
  if fused_linear_cross_entropy:
1169
1221
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
1170
- modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
1222
+ if model is not None:
1223
+ model.forward = MethodType(qwen2_lce_forward, model)
1224
+ else:
1225
+ modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
1171
1226
  else: # if version < 4.46.1
1172
1227
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
1173
- modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
1228
+ if model is not None:
1229
+ model.forward = MethodType(qwen2_lce_forward_deprecated, model)
1230
+ else:
1231
+ modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
1174
1232
 
1175
1233
  if swiglu:
1176
1234
  modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
@@ -1226,7 +1284,10 @@ def apply_liger_kernel_to_qwen3(
1226
1284
  nn.functional.cross_entropy = liger_cross_entropy
1227
1285
 
1228
1286
  if fused_linear_cross_entropy:
1229
- modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
1287
+ if model is not None:
1288
+ model.forward = MethodType(qwen3_lce_forward, model)
1289
+ else:
1290
+ modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
1230
1291
 
1231
1292
  if swiglu:
1232
1293
  modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
@@ -1281,7 +1342,10 @@ def apply_liger_kernel_to_qwen3_moe(
1281
1342
  nn.functional.cross_entropy = liger_cross_entropy
1282
1343
 
1283
1344
  if fused_linear_cross_entropy:
1284
- modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
1345
+ if model is not None:
1346
+ model.forward = MethodType(qwen3_lce_forward, model)
1347
+ else:
1348
+ modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
1285
1349
 
1286
1350
  if swiglu:
1287
1351
  modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
@@ -1350,12 +1414,15 @@ def apply_liger_kernel_to_qwen2_vl(
1350
1414
  if rms_norm:
1351
1415
  # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
1352
1416
  modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
1353
- if layer_norm:
1417
+ if layer_norm and model is None:
1354
1418
  modeling_qwen2_vl.LayerNorm = LigerLayerNorm
1355
1419
  if cross_entropy:
1356
1420
  modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
1357
1421
  if fused_linear_cross_entropy:
1358
- modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
1422
+ if model is not None:
1423
+ model.forward = MethodType(qwen2_vl_lce_forward, model)
1424
+ else:
1425
+ modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
1359
1426
  if swiglu:
1360
1427
  modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
1361
1428
 
@@ -1443,7 +1510,10 @@ def apply_liger_kernel_to_qwen2_5_vl(
1443
1510
  if cross_entropy:
1444
1511
  modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
1445
1512
  if fused_linear_cross_entropy:
1446
- modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
1513
+ if model is not None:
1514
+ model.forward = MethodType(qwen2_5_vl_lce_forward, model)
1515
+ else:
1516
+ modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
1447
1517
  if swiglu:
1448
1518
  modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
1449
1519
 
@@ -1530,10 +1600,16 @@ def apply_liger_kernel_to_phi3(
1530
1600
  modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
1531
1601
  if fused_linear_cross_entropy:
1532
1602
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
1533
- modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
1603
+ if model is not None:
1604
+ model.forward = MethodType(phi3_lce_forward, model)
1605
+ else:
1606
+ modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
1534
1607
  else: # if version < 4.46.1
1535
1608
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
1536
- modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
1609
+ if model is not None:
1610
+ model.forward = MethodType(phi3_lce_forward_deprecated, model)
1611
+ else:
1612
+ modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
1537
1613
 
1538
1614
  if model is not None:
1539
1615
  # The model instance already exists, so we need to additionally patch the
@@ -1597,7 +1673,10 @@ def apply_liger_kernel_to_olmo2(
1597
1673
 
1598
1674
  nn.functional.cross_entropy = liger_cross_entropy
1599
1675
  if fused_linear_cross_entropy:
1600
- modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
1676
+ if model is not None:
1677
+ model.forward = MethodType(olmo2_lce_forward, model)
1678
+ else:
1679
+ modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
1601
1680
 
1602
1681
  if model is not None:
1603
1682
  # The model instance already exists, so we need to additionally patch the
@@ -1661,7 +1740,10 @@ def apply_liger_kernel_to_glm4(
1661
1740
 
1662
1741
  nn.functional.cross_entropy = liger_cross_entropy
1663
1742
  if fused_linear_cross_entropy:
1664
- modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
1743
+ if model is not None:
1744
+ model.forward = MethodType(glm4_lce_forward, model)
1745
+ else:
1746
+ modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
1665
1747
 
1666
1748
  if model is not None:
1667
1749
  # The model instance already exists, so we need to additionally patch the
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.10.dev20250629005644
3
+ Version: 0.5.10.dev20250630172023
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -53,7 +53,7 @@ liger_kernel/transformers/grpo_loss.py,sha256=uAkUNKSnUGEOqa82L9w2e6AI1kcmG8K45-
53
53
  liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
54
54
  liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
55
55
  liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
56
- liger_kernel/transformers/monkey_patch.py,sha256=3KqEl_-WlXgUoEAEYgGs-SPolASshGem2ISFemzQAIc,81705
56
+ liger_kernel/transformers/monkey_patch.py,sha256=YkX0LT6lISg3UTqFjjt9kTr36WgiHvYTQObAS1_Bmi4,85172
57
57
  liger_kernel/transformers/multi_token_attention.py,sha256=l9VDICK0dfmifUDW668hGscP8AHq2rYcM2oGUa3baRQ,1751
58
58
  liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
59
59
  liger_kernel/transformers/rms_norm.py,sha256=vkekcvTeWY8vL4H6hg3t0XeY0Ew_3OFMPHuzqlxPPVw,2719
@@ -88,9 +88,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
88
88
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
89
89
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
90
90
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
91
- liger_kernel_nightly-0.5.10.dev20250629005644.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
92
- liger_kernel_nightly-0.5.10.dev20250629005644.dist-info/METADATA,sha256=FMeKbXVH-02gQ_G0kVMIc6ftN9rv5WeQZ94Br45A9ek,24536
93
- liger_kernel_nightly-0.5.10.dev20250629005644.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
94
- liger_kernel_nightly-0.5.10.dev20250629005644.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
95
- liger_kernel_nightly-0.5.10.dev20250629005644.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
96
- liger_kernel_nightly-0.5.10.dev20250629005644.dist-info/RECORD,,
91
+ liger_kernel_nightly-0.5.10.dev20250630172023.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
92
+ liger_kernel_nightly-0.5.10.dev20250630172023.dist-info/METADATA,sha256=R9S054XUfsyrq9HECn8SHjNLRdXF6KxS6vP1w_fuqjI,24536
93
+ liger_kernel_nightly-0.5.10.dev20250630172023.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
94
+ liger_kernel_nightly-0.5.10.dev20250630172023.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
95
+ liger_kernel_nightly-0.5.10.dev20250630172023.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
96
+ liger_kernel_nightly-0.5.10.dev20250630172023.dist-info/RECORD,,