liger-kernel 0.5.3__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/fused_linear_rlhf.py +213 -0
- liger_kernel/chunked_loss/grpo_loss.py +160 -0
- liger_kernel/chunked_loss/kto_loss.py +9 -9
- liger_kernel/ops/cross_entropy.py +3 -3
- liger_kernel/ops/fused_linear_cross_entropy.py +3 -3
- liger_kernel/ops/fused_linear_jsd.py +3 -3
- liger_kernel/ops/jsd.py +3 -3
- liger_kernel/ops/layer_norm.py +20 -7
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +1 -2
- liger_kernel/transformers/__init__.py +3 -0
- liger_kernel/transformers/cross_entropy.py +3 -3
- liger_kernel/transformers/functional.py +17 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +3 -3
- liger_kernel/transformers/group_norm.py +6 -6
- liger_kernel/transformers/model/olmo2.py +124 -0
- liger_kernel/transformers/monkey_patch.py +171 -27
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/utils.py +49 -0
- {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.4.dist-info}/METADATA +17 -3
- {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.4.dist-info}/RECORD +26 -21
- {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.4.dist-info}/LICENSE +0 -0
- {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.4.dist-info}/NOTICE +0 -0
- {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.4.dist-info}/WHEEL +0 -0
- {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.4.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/monkey_patch.py
CHANGED
@@ -61,6 +61,85 @@ def _patch_layer_norm_module(module, eps=1e-6):
     _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)


+def apply_liger_kernel_to_granite(
+    rope: bool = True,
+    cross_entropy: bool = True,
+    fused_linear_cross_entropy: bool = False,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Granite 3 models
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is False.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+
+
+
+    Debugging notes:
+        If LigerSwiGLUMLP is OK for Llama, it should be fine for Granite, but it's not.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.granite import modeling_granite
+    from transformers.models.granite.modeling_granite import GraniteModel
+
+    if swiglu:
+        modeling_granite.GraniteMLP = LigerSwiGLUMLP
+
+    if rms_norm:
+        modeling_granite.GraniteRMSNorm = LigerRMSNorm
+
+    if rope:
+        modeling_granite.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_granite.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        raise NotImplementedError("LigerFusedLinearCrossEntropy is not available for Granite models.")
+        # NOTE: Granite model `GraniteForCausalLM.forward` scales logits each
+        # call, so we can't sidestep logit materialization. A bit more work
+        # would be needed to add a scaling term to the `LigerFusedLinearCrossEntropyFunction`
+        # for the logit output.
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules (e.g. GraniteRMSNorm or GraniteMLP)
+
+        # get the base model from the model instance
+        base_model: GraniteModel = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_llama(
     rope: bool = True,
     cross_entropy: bool = False,
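For orientation, a minimal usage sketch of the new Granite entry point (not part of the package diff; the checkpoint id below is only an example):

```python
# Hedged sketch: patch the Granite 3 module classes before the model is instantiated.
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_granite

# fused_linear_cross_entropy is not supported for Granite and raises NotImplementedError.
apply_liger_kernel_to_granite(rope=True, rms_norm=True, swiglu=True)

model = AutoModelForCausalLM.from_pretrained("ibm-granite/granite-3.1-2b-instruct")  # example checkpoint
```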
@@ -85,9 +164,9 @@ def apply_liger_kernel_to_llama(
         loaded. Default is None.
     """

-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    )
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )

     from transformers.models.llama import modeling_llama
     from transformers.models.llama.modeling_llama import LlamaModel
@@ -159,9 +238,9 @@ def apply_liger_kernel_to_mllama(
         loaded. Default is None.
     """

-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    )
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )

     from transformers.models.mllama import modeling_mllama
     from transformers.models.mllama.modeling_mllama import MllamaForCausalLM
@@ -261,9 +340,9 @@ def apply_liger_kernel_to_mistral(
     model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    )
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )

     from transformers.models.mistral import modeling_mistral
     from transformers.models.mistral.modeling_mistral import MistralModel
@@ -321,9 +400,9 @@ def apply_liger_kernel_to_mixtral(
         loaded. Default is None.
     """

-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    )
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )

     from transformers.models.mixtral import modeling_mixtral
     from transformers.models.mixtral.modeling_mixtral import MixtralModel
@@ -393,9 +472,9 @@ def apply_liger_kernel_to_gemma(
     model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    )
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )

     from transformers.models.gemma import modeling_gemma
     from transformers.models.gemma.modeling_gemma import GemmaModel
@@ -467,9 +546,9 @@ def apply_liger_kernel_to_gemma2(
     model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    )
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )

     from transformers.models.gemma2 import modeling_gemma2
     from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
@@ -544,9 +623,9 @@ def apply_liger_kernel_to_qwen2(
     model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    )
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )

     from transformers.models.qwen2 import modeling_qwen2
     from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
@@ -619,9 +698,9 @@ def apply_liger_kernel_to_qwen2_vl(
     model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    )
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )

     from transformers.models.qwen2_vl import modeling_qwen2_vl
     from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
@@ -689,9 +768,9 @@ def apply_liger_kernel_to_phi3(
     model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    )
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )

     from transformers.models.phi3 import modeling_phi3
     from transformers.models.phi3.modeling_phi3 import Phi3Model
@@ -735,15 +814,80 @@ def apply_liger_kernel_to_phi3(
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+def apply_liger_kernel_to_olmo2(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace OLMO2 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU Olmo2MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.olmo2 import modeling_olmo2
+    from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
+
+    from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
+
+    if rope:
+        modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_olmo2.Olmo2RMSNorm = partial(LigerRMSNorm, in_place=False)
+    if swiglu:
+        modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Olmo2Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
+
+
 # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
 MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
     "gemma2": apply_liger_kernel_to_gemma2,
     "llama": apply_liger_kernel_to_llama,
+    "granite": apply_liger_kernel_to_granite,
     "mllama": apply_liger_kernel_to_mllama,
     "mllama_text_model": apply_liger_kernel_to_mllama,
     "mistral": apply_liger_kernel_to_mistral,
     "mixtral": apply_liger_kernel_to_mixtral,
+    "olmo2": apply_liger_kernel_to_olmo2,
     "qwen2": apply_liger_kernel_to_qwen2,
     "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
     "phi3": apply_liger_kernel_to_phi3,
liger_kernel/transformers/tvd.py
ADDED
@@ -0,0 +1,13 @@
+import torch.nn as nn
+
+from liger_kernel.ops.tvd import LigerTVDLossFunction
+
+
+class LigerTVDLoss(nn.Module):
+    def __init__(self, reduction="batchmean", ignore_index: int = -100):
+        super(LigerTVDLoss, self).__init__()
+        self.reduction = reduction
+        self.ignore_index = ignore_index
+
+    def forward(self, p, q, shift_labels=None):
+        return LigerTVDLossFunction.apply(p, q, shift_labels, self.reduction, self.ignore_index)
liger_kernel/utils.py
CHANGED
@@ -9,5 +9,54 @@ def infer_device():
         return "cuda"
     elif torch.xpu.is_available():
         return "xpu"
+    elif torch.hip.is_available():
+        return "hip"
     else:
         return "cpu"
+
+
+def transformers_version_dispatch(
+    required_version: str,
+    before_fn,
+    after_fn,
+    before_args: tuple = (),
+    after_args: tuple = (),
+    before_kwargs: dict = None,
+    after_kwargs: dict = None,
+):
+    """
+    Dispatches to different functions based on package version comparison.
+
+    Args:
+        required_version: Version to compare against (e.g. "4.48.0")
+        before_fn: Function to call if package_version < required_version
+        after_fn: Function to call if package_version >= required_version
+        before_args: Positional arguments for before_fn
+        after_args: Positional arguments for after_fn
+        before_kwargs: Keyword arguments for before_fn
+        after_kwargs: Keyword arguments for after_fn
+
+    Returns:
+        Result from either before_fn or after_fn
+
+    Example:
+        >>> rotary_emb = transformers_version_dispatch(
+        ...     "4.48.0",
+        ...     LlamaRotaryEmbedding,
+        ...     LlamaRotaryEmbedding,
+        ...     before_args=(head_dim,),
+        ...     after_args=(LlamaConfig(head_dim=head_dim),),
+        ...     before_kwargs={'device': device},
+        ...     after_kwargs={'device': device}
+        ... )
+    """
+    from packaging import version
+    from transformers import __version__ as transformers_version
+
+    before_kwargs = before_kwargs or {}
+    after_kwargs = after_kwargs or {}
+
+    if version.parse(transformers_version) < version.parse(required_version):
+        return before_fn(*before_args, **before_kwargs)
+    else:
+        return after_fn(*after_args, **after_kwargs)
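A self-contained, hedged sketch of the new `transformers_version_dispatch` helper using plain callables; the docstring's own example targets `LlamaRotaryEmbedding` construction across transformers versions:

```python
# Hedged sketch: route to one of two callables based on the installed transformers version.
from liger_kernel.utils import transformers_version_dispatch

result = transformers_version_dispatch(
    "4.48.0",
    before_fn=lambda name: f"pre-4.48 path for {name}",  # hypothetical callable
    after_fn=lambda name: f"4.48+ path for {name}",      # hypothetical callable
    before_args=("rope",),
    after_args=("rope",),
)
print(result)
```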
{liger_kernel-0.5.3.dist-info → liger_kernel-0.5.4.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: liger_kernel
-Version: 0.5.3
+Version: 0.5.4
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
@@ -97,6 +97,11 @@ Dynamic: requires-dist
         <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
       </a>
     </div>
+    <div style="display: block;">
+      <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
+        <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
+      </a>
+    </div>
   </td>
 </tr>
 </table>
@@ -123,7 +128,7 @@ Dynamic: requires-dist

 **Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduces **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.

-We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
+We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, KTO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).

 ## Supercharge Your Model with Liger Kernel

@@ -188,6 +193,11 @@ y = orpo_loss(lm_head.weight, x, target)
 - `torch >= 2.5.0` Install according to the instruction in Pytorch official webpage.
 - `triton >= 3.0.0` Install from pypi. (e.g. `pip install triton==3.0.0`)

+```bash
+# Need to pass the url when installing
+pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2
+```
+
 ### Optional Dependencies

 - `transformers >= 4.x`: Required if you plan to use the transformers models patching APIs. The specific model you are working will dictate the minimum version of transformers.
@@ -305,6 +315,8 @@ loss.backward()
 | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
+| OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |


 ## Low-level APIs
@@ -333,6 +345,7 @@ loss.backward()
 | Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
 | Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
 | Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
+| Fused Linear KTO Loss | `liger_kernel.chunked_loss.LigerFusedLinearKTOLoss` |

 ### Distillation Kernels

@@ -341,6 +354,7 @@ loss.backward()
 | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
 | JSD | `liger_kernel.transformers.LigerJSD` |
 | Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
+| TVD | `liger_kernel.transformers.LigerTVDLoss` |

 ### Experimental Kernels

@@ -372,7 +386,7 @@ loss.backward()

 - For issues, create a Github ticket in this repository
 - For open discussion, join [our discord channel](https://discord.gg/gpumode)
-- For formal collaboration, send an email to
+- For formal collaboration, send an email to yannchen@linkedin.com

 ## Cite this work


{liger_kernel-0.5.3.dist-info → liger_kernel-0.5.4.dist-info}/RECORD
CHANGED
@@ -1,51 +1,55 @@
 liger_kernel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 liger_kernel/env_report.py,sha256=uhdEC8OydxoZlb7B6YYcAaBF3crGFdIck-4cxaW4NJY,1728
-liger_kernel/utils.py,sha256=
+liger_kernel/utils.py,sha256=FtVUkCGBT1UNasTl6HMNycWwiwHayK6tx-ZDdA-sNX4,1884
 liger_kernel/chunked_loss/README.md,sha256=0FmkFC3hKBqyoDT5uTlIYmrvRkF-EOCR1y-EBU1LpWU,2248
-liger_kernel/chunked_loss/__init__.py,sha256=
+liger_kernel/chunked_loss/__init__.py,sha256=ATu-xX5Fc49Cr6yBOGBRNTo593ZrU5ZCsIuvoIbJWw4,603
 liger_kernel/chunked_loss/cpo_loss.py,sha256=OdBR8WYdHTKpLI_c9DcuwqKSWPeAAeTyREz46Vu_cAY,3682
 liger_kernel/chunked_loss/dpo_loss.py,sha256=wgjnwzLfrMUwV5mXgrq6G1YfQKWnbiFJegaP48BGJHY,4509
 liger_kernel/chunked_loss/functional.py,sha256=THWWpCnRVhTVfnPnyvQjdBvo1JDtxhwLmtZE_yiBBqM,817
 liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=5V8rdva89WyHVbmJ8JOmC4DYNOR6ByXfx3qlUieOZkI,11002
 liger_kernel/chunked_loss/fused_linear_preference.py,sha256=idK9V9NivoVITqVpiG0fEGUHSvinYWkn9-EYXZjR-KQ,18356
+liger_kernel/chunked_loss/fused_linear_rlhf.py,sha256=sAApL4GQ3YL2F-ymIAF61GCpFfBgFcWF5LB4Gzd7LgY,8044
 liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=ZqYlXXhIphkJPxOS7iI70avgrr6x0skEtgpckZTYau0,9819
+liger_kernel/chunked_loss/grpo_loss.py,sha256=M5qlQR-v5Rh8N3P3dPGNhOKygDFJ4516_rJaVPzU_-c,4980
 liger_kernel/chunked_loss/jsd_loss.py,sha256=yRCQdvd3ruTWP4A_BfU8VcZ6LepSUfO0Ob7stGnueQY,6052
-liger_kernel/chunked_loss/kto_loss.py,sha256=
+liger_kernel/chunked_loss/kto_loss.py,sha256=b3ffJyk97e-6XdXd4HFrYyx8wW4A-CU4gOaJSimKLtA,5476
 liger_kernel/chunked_loss/orpo_loss.py,sha256=yjcrrbVeemLYodoSKT-FMSnaPtyKAZ3aOrvPD6tTY6Y,3617
 liger_kernel/chunked_loss/simpo_loss.py,sha256=3TTc7U79Orjgi-Wu81WZkWk5MgsdqKXIOBHgIvDazPw,3865
 liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-liger_kernel/ops/cross_entropy.py,sha256=
-liger_kernel/ops/fused_linear_cross_entropy.py,sha256=
-liger_kernel/ops/fused_linear_jsd.py,sha256=
+liger_kernel/ops/cross_entropy.py,sha256=D6vFFloiuxFXoWfjlIjmfO3tVaWOiYmztw9FKAi5vdU,18608
+liger_kernel/ops/fused_linear_cross_entropy.py,sha256=1Y3Uk_TCSjqKgoG2eot1ptnWXJXXQESqGvOmqAW1gsM,10912
+liger_kernel/ops/fused_linear_jsd.py,sha256=Seshez2qaM6HiTQ8_HEqSwhaeVruNT1SvIM4ZrAPBEU,9602
 liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
 liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
-liger_kernel/ops/jsd.py,sha256=
+liger_kernel/ops/jsd.py,sha256=0jNeRxpcNI5ckxCdoCNyO5GEedLIuzx3lz6KAiksc4o,6109
 liger_kernel/ops/kl_div.py,sha256=MnfuYqqQESON1X2Swy064x1urKtMFdgeSWd60VttBXI,8420
-liger_kernel/ops/layer_norm.py,sha256=
+liger_kernel/ops/layer_norm.py,sha256=6roQjioyg-9O2qLPV8nL4U0-5UH80tdzOMTWwjvDnn8,7961
 liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
 liger_kernel/ops/rms_norm.py,sha256=PWLJcdIKU5e-8BuYFHd9Cqlq6wmr6fUXKi9zQD4LetU,11727
 liger_kernel/ops/rope.py,sha256=ofmBOkUpZZO-Q8Z5B_LOFYYLD-YT-8WnJ4vGOrDYouI,8943
 liger_kernel/ops/swiglu.py,sha256=KmgMjaJQnbLLgZn2nEpbwHU_xpnYRweCyrLQSVvM1vA,3015
-liger_kernel/ops/
+liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
+liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
 liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
 liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
-liger_kernel/transformers/__init__.py,sha256=
+liger_kernel/transformers/__init__.py,sha256=6v_VcV1GQ9ISgNCd-ZxtmEg_s5GTBQ9F-s1KrFkYzPQ,2265
 liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
-liger_kernel/transformers/cross_entropy.py,sha256=
-liger_kernel/transformers/functional.py,sha256=
-liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=
+liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
+liger_kernel/transformers/functional.py,sha256=ShLD3eb--XKNtllznCrOYTbo4f-1KVwzi0KLMICdrn4,4942
+liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=09Rt7FZzLH42VOcIbQ4dlQd0o3Rlb4vk6fqiOQ7WTD8,1778
 liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
 liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
-liger_kernel/transformers/group_norm.py,sha256=
+liger_kernel/transformers/group_norm.py,sha256=6qMAWOprr4SzP0YhNVNGQIBpM5aUHplUD2VuGJrMBz0,2173
 liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
 liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
 liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
-liger_kernel/transformers/monkey_patch.py,sha256=
+liger_kernel/transformers/monkey_patch.py,sha256=g3i3q5McBg23A3Mnviw-Eb32le1hvN7jByzONa9ngcs,44000
 liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
 liger_kernel/transformers/rms_norm.py,sha256=GqCEJuGt0YdqqlMcToE0Wp4A8YFquDa4UUSyH2uFW2A,1191
 liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
 liger_kernel/transformers/swiglu.py,sha256=i9WTqcNRqReU4XJs391IPbl-I5X0wG4T72D4pqGFfJg,2422
 liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
+liger_kernel/transformers/tvd.py,sha256=XrRfyJIqN6HFxXk8MYyFVZM1OLz3mtSbRZvWfZ_JerQ,450
 liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1wcRgDSyjBMv5i1a7BrDPDQw,881
 liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 liger_kernel/transformers/model/gemma.py,sha256=ky89b3aWPaeTGRMC-745KgixtQIRXzNAiCORAMLn9yo,9654
@@ -54,6 +58,7 @@ liger_kernel/transformers/model/llama.py,sha256=3LJFXKFDKvEakaWPc_NicSFst4Y_hdSM
 liger_kernel/transformers/model/mistral.py,sha256=MVRksI5_j_8WJu8znOHKCdSI5jSu-S7cdFYzt9m_vIQ,5180
 liger_kernel/transformers/model/mixtral.py,sha256=jpZJkpl625Q-JHWarj2MqT5mRaSsiCtg0c9vVyvOdCY,11430
 liger_kernel/transformers/model/mllama.py,sha256=qWexBdskuN3gPJvPUwt4J0nU675tGD6W7wxgRZ9Bifg,11145
+liger_kernel/transformers/model/olmo2.py,sha256=yyksS6E4fuWd8asEW8rEDBKqZpFmP4ITCM_bjIDZaoY,5124
 liger_kernel/transformers/model/phi3.py,sha256=biRa8fph9qdnQmkD9I21t5XIjpIt1i6UKU4uk8Up8pU,10292
 liger_kernel/transformers/model/qwen2.py,sha256=14UuPjxB-tjqWn85Tn4fqBFvVhVsth5iPEt8kJSMiew,9581
 liger_kernel/transformers/model/qwen2_vl.py,sha256=yMLqsfSYcvhClUpTUjGoADiOxfLB2B8240VdrPP0c8s,9851
@@ -61,9 +66,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel-0.5.
-liger_kernel-0.5.
-liger_kernel-0.5.
-liger_kernel-0.5.
-liger_kernel-0.5.
-liger_kernel-0.5.
+liger_kernel-0.5.4.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel-0.5.4.dist-info/METADATA,sha256=Zw7n3Ey6vUed4E54H9-TzKmhuOpd9P2ZFMVL-zYUnew,22255
+liger_kernel-0.5.4.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel-0.5.4.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+liger_kernel-0.5.4.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel-0.5.4.dist-info/RECORD,,
{liger_kernel-0.5.3.dist-info → liger_kernel-0.5.4.dist-info}/LICENSE: file without changes
{liger_kernel-0.5.3.dist-info → liger_kernel-0.5.4.dist-info}/NOTICE: file without changes
{liger_kernel-0.5.3.dist-info → liger_kernel-0.5.4.dist-info}/WHEEL: file without changes
{liger_kernel-0.5.3.dist-info → liger_kernel-0.5.4.dist-info}/top_level.txt: file without changes