liger-kernel-nightly 0.5.3.dev20250221002845-py3-none-any.whl → 0.5.3.dev20250221011057-py3-none-any.whl
- liger_kernel/transformers/__init__.py +1 -0
- liger_kernel/transformers/monkey_patch.py +80 -0
- {liger_kernel_nightly-0.5.3.dev20250221002845.dist-info → liger_kernel_nightly-0.5.3.dev20250221011057.dist-info}/METADATA +3 -2
- {liger_kernel_nightly-0.5.3.dev20250221002845.dist-info → liger_kernel_nightly-0.5.3.dev20250221011057.dist-info}/RECORD +8 -8
- {liger_kernel_nightly-0.5.3.dev20250221002845.dist-info → liger_kernel_nightly-0.5.3.dev20250221011057.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221002845.dist-info → liger_kernel_nightly-0.5.3.dev20250221011057.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221002845.dist-info → liger_kernel_nightly-0.5.3.dev20250221011057.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221002845.dist-info → liger_kernel_nightly-0.5.3.dev20250221011057.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/__init__.py

@@ -9,6 +9,7 @@ from liger_kernel.transformers.monkey_patch import _apply_liger_kernel  # noqa:
 from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral  # noqa: F401
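With this re-export, the Granite patch can be applied at the package level before a model is constructed. The sketch below is illustrative and not part of the diff; the checkpoint name is a placeholder.

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_granite

# Patch the Granite modeling classes *before* the model is built so that
# RoPE, RMSNorm, SwiGLU, and cross entropy are replaced by Liger kernels.
apply_liger_kernel_to_granite(
    rope=True,
    cross_entropy=True,
    fused_linear_cross_entropy=False,  # not supported for Granite (see the monkey_patch hunk below)
    rms_norm=True,
    swiglu=True,
)

# Placeholder checkpoint name; any Granite 3 causal-LM checkpoint is loaded the same way.
model = AutoModelForCausalLM.from_pretrained("ibm-granite/granite-3.0-2b-instruct")
```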
liger_kernel/transformers/monkey_patch.py

@@ -61,6 +61,85 @@ def _patch_layer_norm_module(module, eps=1e-6):
     _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
 
 
+def apply_liger_kernel_to_granite(
+    rope: bool = True,
+    cross_entropy: bool = True,
+    fused_linear_cross_entropy: bool = False,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Granite 3 models
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is False.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+
+
+    Debugging notes:
+        If LigerSwiGLUMLP is OK for Llama, it should be fine for Granite, but it's not.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.granite import modeling_granite
+    from transformers.models.granite.modeling_granite import GraniteModel
+
+    if swiglu:
+        modeling_granite.GraniteMLP = LigerSwiGLUMLP
+
+    if rms_norm:
+        modeling_granite.GraniteRMSNorm = LigerRMSNorm
+
+    if rope:
+        modeling_granite.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_granite.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        raise NotImplementedError("LigerFusedLinearCrossEntropy is not available for Granite models.")
+        # NOTE: Granite model `GraniteForCausalLM.forward` scales logits each
+        # call, so we can't sidestep logit materialization. A bit more work
+        # would be needed to add a scaling term to the `LigerFusedLinearCrossEntropyFunction`
+        # for the logit output.
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules (e.g. GraniteRMSNorm or GraniteMLP)
+
+        # get the base model from the model instance
+        base_model: GraniteModel = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_llama(
     rope: bool = True,
     cross_entropy: bool = False,
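The `model is not None` branch above covers models that are already instantiated: instead of swapping class attributes on `modeling_granite`, it patches the live modules in place. A minimal sketch of that path, assuming an already-loaded Granite checkpoint (the name is a placeholder):

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_granite

# Placeholder checkpoint name. Because the model already exists, class-level
# monkey patching alone would not touch its instantiated layers.
model = AutoModelForCausalLM.from_pretrained("ibm-granite/granite-3.0-2b-instruct")

# The instance path walks base_model.norm and every decoder layer, rebinding
# the SwiGLU forward and patching the RMSNorm modules in place.
apply_liger_kernel_to_granite(model=model, rope=True, rms_norm=True, swiglu=True)
```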
liger_kernel/transformers/monkey_patch.py

@@ -740,6 +819,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
     "gemma2": apply_liger_kernel_to_gemma2,
     "llama": apply_liger_kernel_to_llama,
+    "granite": apply_liger_kernel_to_granite,
     "mllama": apply_liger_kernel_to_mllama,
     "mllama_text_model": apply_liger_kernel_to_mllama,
     "mistral": apply_liger_kernel_to_mistral,
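Registering `"granite"` here is what lets config-driven application pick up the new function. A rough sketch of the dispatch, based only on the registry above (how `_apply_liger_kernel` uses it internally is an assumption here, and the checkpoint name is a placeholder):

```python
from transformers import AutoConfig

from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN

# Placeholder checkpoint; Granite 3 configs report model_type == "granite".
config = AutoConfig.from_pretrained("ibm-granite/granite-3.0-2b-instruct")

apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(config.model_type)
if apply_fn is not None:
    # With the new entry, Granite checkpoints are routed to
    # apply_liger_kernel_to_granite alongside llama, gemma, mistral, etc.
    apply_fn()
```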
{liger_kernel_nightly-0.5.3.dev20250221002845.dist-info → liger_kernel_nightly-0.5.3.dev20250221011057.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.5.3.dev20250221002845
+Version: 0.5.3.dev20250221011057
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
@@ -126,7 +126,7 @@ Requires-Dist: mkdocs-material; extra == "dev"
 
 **Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduces **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.
 
-We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
+We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, KTO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
 
 ## Supercharge Your Model with Liger Kernel
 
@@ -341,6 +341,7 @@ loss.backward()
 | Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
 | Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
 | Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
+| Fused Linear KTO Loss | `liger_kernel.chunked_loss.LigerFusedLinearKTOLoss` |
 
 ### Distillation Kernels
 
{liger_kernel_nightly-0.5.3.dev20250221002845.dist-info → liger_kernel_nightly-0.5.3.dev20250221011057.dist-info}/RECORD

@@ -32,7 +32,7 @@ liger_kernel/ops/tvd.py,sha256=9wVCijj2vBtgiLeUHhl7hy_LAiJ3liPIYOGMSU3P1ro,6407
 liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
 liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
 liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
-liger_kernel/transformers/__init__.py,sha256=
+liger_kernel/transformers/__init__.py,sha256=i6GPkP5-esFBh205nF4MluNrL7KNugseGiUKdSHGW70,2172
 liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
 liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
 liger_kernel/transformers/functional.py,sha256=zahXVCjA2NxcVFpAgajILIRN0GO6mrbfLPgONUkTrY8,4940

@@ -43,7 +43,7 @@ liger_kernel/transformers/group_norm.py,sha256=6qMAWOprr4SzP0YhNVNGQIBpM5aUHplUD
 liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
 liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
 liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
-liger_kernel/transformers/monkey_patch.py,sha256=
+liger_kernel/transformers/monkey_patch.py,sha256=kSJE1aMLN0e4jxiePUISRPcyvhWyzhOaqIwLW6rG0Zo,41191
 liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
 liger_kernel/transformers/rms_norm.py,sha256=GqCEJuGt0YdqqlMcToE0Wp4A8YFquDa4UUSyH2uFW2A,1191
 liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
@@ -65,9 +65,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.3.
-liger_kernel_nightly-0.5.3.
-liger_kernel_nightly-0.5.3.
-liger_kernel_nightly-0.5.3.
-liger_kernel_nightly-0.5.3.
-liger_kernel_nightly-0.5.3.
+liger_kernel_nightly-0.5.3.dev20250221011057.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.3.dev20250221011057.dist-info/METADATA,sha256=eOu7vQd8uhDyA4a1jUXnZRPAH7GL3eB4ASMYVb9oDZA,21963
+liger_kernel_nightly-0.5.3.dev20250221011057.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.3.dev20250221011057.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+liger_kernel_nightly-0.5.3.dev20250221011057.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.3.dev20250221011057.dist-info/RECORD,,