liger-kernel-nightly 0.5.9.dev20250519025610__py3-none-any.whl → 0.5.9.dev20250519035525__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -35,6 +35,13 @@ from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
35
35
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
36
36
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
37
37
 
38
+ try:
39
+ import peft
40
+
41
+ PEFT_AVAILABLE = True
42
+ except ImportError:
43
+ PEFT_AVAILABLE = False
44
+
38
45
  transformer_version = version.parse(transformers.__version__)
39
46
 
40
47
  logger = logging.getLogger(__name__)
@@ -48,22 +55,68 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
48
55
 
49
56
 
50
57
  def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
51
- module.offset = offset
52
- module.casting_mode = casting_mode
53
- module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
54
- module.in_place = in_place
55
- _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
56
- _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
57
- module.__class__.__name__ = LigerRMSNorm.__name__
58
+ # Check if the module is a PEFT ModulesToSaveWrapper
59
+ # If it is, we need to patch the modules_to_save.default and original_modules
60
+ if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
61
+ module.modules_to_save.default.offset = offset
62
+ module.modules_to_save.default.casting_mode = casting_mode
63
+ module.modules_to_save.default.variance_epsilon = (
64
+ getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
65
+ )
66
+ module.modules_to_save.default.in_place = in_place
67
+ module.original_module.offset = offset
68
+ module.original_module.casting_mode = casting_mode
69
+ module.original_module.variance_epsilon = (
70
+ getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
71
+ )
72
+ module.original_module.in_place = in_place
73
+ _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
74
+ _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
75
+ _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
76
+ _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
77
+ module.modules_to_save.default.__class__.__name__ = LigerRMSNorm.__name__
78
+ module.original_module.__class__.__name__ = LigerRMSNorm.__name__
79
+ else:
80
+ module.offset = offset
81
+ module.casting_mode = casting_mode
82
+ module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
83
+ module.in_place = in_place
84
+ _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
85
+ _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
86
+ module.__class__.__name__ = LigerRMSNorm.__name__
58
87
 
59
88
 
60
89
  def _patch_layer_norm_module(module, eps=1e-6):
61
- module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
62
- module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
63
-
64
- _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
65
- _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
66
- module.__class__.__name__ = LigerLayerNorm.__name__
90
+ # Check if the module is a PEFT ModulesToSaveWrapper
91
+ # If it is, we need to patch the modules_to_save.default and original_modules
92
+ if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
93
+ module.hidden_size = module.normalized_shape
94
+ _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
95
+ _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
96
+ module.modules_to_save.default.variance_epsilon = (
97
+ getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
98
+ )
99
+ module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
100
+ module, "normalized_shape", None
101
+ )
102
+ module.original_module.variance_epsilon = (
103
+ getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
104
+ )
105
+ module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
106
+ module, "normalized_shape", None
107
+ )
108
+ _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
109
+ _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
110
+ _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
111
+ _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
112
+ module.modules_to_save.default.__class__.__name__ = LigerLayerNorm.__name__
113
+ module.original_module.__class__.__name__ = LigerLayerNorm.__name__
114
+ else:
115
+ module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
116
+ module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
117
+ _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
118
+ _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
119
+ module.__class__.__name__ = LigerLayerNorm.__name__
67
120
 
68
121
 
69
122
  def _patch_swiglu_module(module, liger_module):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.9.dev20250519025610
3
+ Version: 0.5.9.dev20250519035525
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -50,7 +50,7 @@ liger_kernel/transformers/grpo_loss.py,sha256=uAkUNKSnUGEOqa82L9w2e6AI1kcmG8K45-
50
50
  liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
51
51
  liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
52
52
  liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
53
- liger_kernel/transformers/monkey_patch.py,sha256=k8WIkx_f3ObG6TjhIiN_4KeOABurB2W7xy7td0ie-W8,71339
53
+ liger_kernel/transformers/monkey_patch.py,sha256=DKv5-4KyXLiVhAJ9WVFv1I1i1DzjaudTrhqx6EVYViU,74505
54
54
  liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
55
55
  liger_kernel/transformers/rms_norm.py,sha256=GqCEJuGt0YdqqlMcToE0Wp4A8YFquDa4UUSyH2uFW2A,1191
56
56
  liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
@@ -82,9 +82,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
82
82
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
83
83
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
84
84
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
85
- liger_kernel_nightly-0.5.9.dev20250519025610.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
86
- liger_kernel_nightly-0.5.9.dev20250519025610.dist-info/METADATA,sha256=y96ZmoWt54lwSvXqmZylo4V_wUHZ2dD2Xb29tV0jvLA,23970
87
- liger_kernel_nightly-0.5.9.dev20250519025610.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
88
- liger_kernel_nightly-0.5.9.dev20250519025610.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
89
- liger_kernel_nightly-0.5.9.dev20250519025610.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
90
- liger_kernel_nightly-0.5.9.dev20250519025610.dist-info/RECORD,,
85
+ liger_kernel_nightly-0.5.9.dev20250519035525.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
86
+ liger_kernel_nightly-0.5.9.dev20250519035525.dist-info/METADATA,sha256=_eK-bGVg1jGKnvDjZ_ds17cMG2XCdfcBr3M73rlL5xI,23970
87
+ liger_kernel_nightly-0.5.9.dev20250519035525.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
88
+ liger_kernel_nightly-0.5.9.dev20250519035525.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
89
+ liger_kernel_nightly-0.5.9.dev20250519035525.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
90
+ liger_kernel_nightly-0.5.9.dev20250519035525.dist-info/RECORD,,