liger-kernel 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
@@ -1,9 +1,23 @@
+ import inspect
+ import logging
+ from functools import partial
+
  from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
  from liger_kernel.transformers.geglu import LigerGEGLUMLP
- from liger_kernel.transformers.model.llama import lce_forward
+ from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
+ from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
+ from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
+ from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
+ from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
  from liger_kernel.transformers.rms_norm import LigerRMSNorm
  from liger_kernel.transformers.rope import liger_rotary_pos_emb
- from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP, LigerSwiGLUMLP
+ from liger_kernel.transformers.swiglu import (
+     LigerBlockSparseTop2MLP,
+     LigerPhi3SwiGLUMLP,
+     LigerSwiGLUMLP,
+ )
+
+ logger = logging.getLogger(__name__)


  def apply_liger_kernel_to_llama(
@@ -20,7 +34,7 @@ def apply_liger_kernel_to_llama(
          rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
          cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
          fused_linear_cross_entropy (bool):
-             Whether to apply Liger's fused lienar cross entropy loss. Default is True.
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
              `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
              If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
          rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
@@ -42,12 +56,13 @@ def apply_liger_kernel_to_llama(
      if cross_entropy:
          modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
      if fused_linear_cross_entropy:
-         modeling_llama.LlamaForCausalLM.forward = lce_forward
+         modeling_llama.LlamaForCausalLM.forward = llama_lce_forward


  def apply_liger_kernel_to_mistral(
      rope: bool = True,
-     cross_entropy: bool = True,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
      rms_norm: bool = True,
      swiglu: bool = True,
  ) -> None:
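Editor's note on the pattern these hunks rely on: the patch functions rebind attributes on the Hugging Face classes and modules at runtime. A minimal sketch of that monkey-patching idea, using hypothetical stand-ins (`TinyModel`, `patched_forward`) rather than anything from liger-kernel or transformers:

```python
class TinyModel:
    def forward(self, x):
        return x + 1


def patched_forward(self, x):
    # Stand-in for an optimized forward such as llama_lce_forward.
    return x * 2


model = TinyModel()                  # instantiated before the patch
TinyModel.forward = patched_forward  # rebind the class attribute
print(model.forward(3))              # 6 -- method lookup goes through the class,
                                     # so existing instances pick up the patch too
```

Rebinding a module-level name instead (for example `modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss`) only affects code that looks the name up after the patch, i.e. models constructed later.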
@@ -57,9 +72,17 @@ def apply_liger_kernel_to_mistral(
      Args:
          rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
          cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
          rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
          swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
      """
+     assert not (
+         cross_entropy and fused_linear_cross_entropy
+     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

      from transformers.models.mistral import modeling_mistral

@@ -69,6 +92,8 @@ def apply_liger_kernel_to_mistral(
          modeling_mistral.MistralRMSNorm = LigerRMSNorm
      if cross_entropy:
          modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
+     if fused_linear_cross_entropy:
+         modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
      if swiglu:
          modeling_mistral.MistralMLP = LigerSwiGLUMLP

@@ -94,7 +119,7 @@ def apply_liger_kernel_to_mixtral(
      if rope:
          modeling_mixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
      if rms_norm:
-         modeling_mixtral.MistralRMSNorm = LigerRMSNorm
+         modeling_mixtral.MixtralRMSNorm = LigerRMSNorm
      if cross_entropy:
          modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
      if swiglu:
@@ -104,12 +129,13 @@ def apply_liger_kernel_to_mixtral(
  def apply_liger_kernel_to_gemma(
      rope: bool = True,
      cross_entropy: bool = True,
+     fused_linear_cross_entropy: bool = True,
      rms_norm: bool = True,
      geglu: bool = True,
  ) -> None:
      """
-     Apply Liger kernels to replace original implementation in HuggingFace Gemma2 models
-     to make GPU go burrr.
+     Apply Liger kernels to replace original implementation in HuggingFace Gemma
+     (Gemma 1 and 1.1 supported, for Gemma2 please use `apply_liger_kernel_to_gemma2` ) to make GPU go burrr.

      Args:
          rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
@@ -117,14 +143,181 @@ def apply_liger_kernel_to_gemma(
          rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
          geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
      """
-     # TODO(yundai424): add convergence test for gemma
+     assert not (
+         cross_entropy and fused_linear_cross_entropy
+     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
      from transformers.models.gemma import modeling_gemma

      if rope:
          modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
      if rms_norm:
-         modeling_gemma.GemmaRMSNorm = LigerRMSNorm
+         # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
+         modeling_gemma.GemmaRMSNorm = partial(
+             LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
+         )
      if cross_entropy:
          modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
      if geglu:
          modeling_gemma.GemmaMLP = LigerGEGLUMLP
+     if fused_linear_cross_entropy:
+         modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+
+
+ def apply_liger_kernel_to_gemma2(
+     rope: bool = True,
+     cross_entropy: bool = True,
+     rms_norm: bool = True,
+     geglu: bool = True,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace Gemma2
+     (for Gemma1 please use `apply_liger_kernel_to_gemma`) to make GPU go burrr.
+
+     Args:
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
+     """
+     from transformers.models.gemma2 import modeling_gemma2
+
+     if rope:
+         modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
+     if rms_norm:
+         # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
+         modeling_gemma2.Gemma2RMSNorm = partial(
+             LigerRMSNorm, offset=1.0, init_fn="zeros"
+         )
+     if cross_entropy:
+         modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
+     if geglu:
+         modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
+
+
+ def apply_liger_kernel_to_qwen2(
+     rope: bool = True,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     rms_norm: bool = True,
+     swiglu: bool = True,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace Qwen2 models
+
+     Args:
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+     """
+     assert not (
+         cross_entropy and fused_linear_cross_entropy
+     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
+     from transformers.models.qwen2 import modeling_qwen2
+
+     if rope:
+         modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
+     if rms_norm:
+         modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
+     if cross_entropy:
+         modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
+     if fused_linear_cross_entropy:
+         modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
+     if swiglu:
+         modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
+
+
+ def apply_liger_kernel_to_phi3(
+     rope: bool = True,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     rms_norm: bool = True,
+     swiglu: bool = True,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.
+
+     Args:
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
+     """
+     assert not (
+         cross_entropy and fused_linear_cross_entropy
+     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
+     from transformers.models.phi3 import modeling_phi3
+
+     if rope:
+         modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb  # Same as Gemma
+     if rms_norm:
+         modeling_phi3.Phi3RMSNorm = LigerRMSNorm  # Same as Llama
+     if swiglu:
+         modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
+     if cross_entropy:
+         modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
+     if fused_linear_cross_entropy:
+         modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
+
+
+ # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
+ MODEL_TYPE_TO_APPLY_LIGER_FN = {
+     "gemma": apply_liger_kernel_to_gemma,
+     "gemma2": apply_liger_kernel_to_gemma2,
+     "llama": apply_liger_kernel_to_llama,
+     "mistral": apply_liger_kernel_to_mistral,
+     "mixtral": apply_liger_kernel_to_mixtral,
+     "qwen2": apply_liger_kernel_to_qwen2,
+     "phi3": apply_liger_kernel_to_phi3,
+ }
+
+
+ def _apply_liger_kernel(model_type: str = "", **kwargs) -> None:
+     """
+     Applies Liger kernels based on the specified model type. The custom
+     kernels for the specified model type will be applied with the provided
+     keyword arguments, otherwise the default configuration will be used.
+
+     Args:
+     - model_type: the model types as defined in transformers/models/auto/modeling_auto.py
+       and specified in the model's config.json
+     - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
+     """
+
+     if not model_type:
+         logger.info("Model type was not provided. No Liger kernels will be applied.")
+         return
+
+     if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
+         logger.info(
+             f"There are currently no Liger kernels supported for model type: {model_type}."
+         )
+         return
+
+     apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
+     apply_fn_signature = inspect.signature(apply_fn)
+
+     # Filter out the keyword arguments that are not supported by the apply function
+     applicable_kwargs = {
+         key: value
+         for key, value in kwargs.items()
+         if key in apply_fn_signature.parameters
+     }
+
+     logger.info(
+         f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}"
+     )
+
+     # Apply the default combination of liger kernels available for the model
+     apply_fn(**applicable_kwargs)
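`_apply_liger_kernel` is the internal dispatcher kept for the Hugging Face Trainer integration: it looks up the patch function by `model_type` and drops any keyword argument the target function does not accept. A usage sketch, assuming liger-kernel 0.2.0 is installed; the `geglu` kwarg below is deliberately one that Qwen2's patch function does not take, so the `inspect.signature` filter discards it:

```python
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel

# Dispatches to apply_liger_kernel_to_qwen2; `geglu` is not in that function's
# signature, so it is silently filtered out instead of raising a TypeError.
_apply_liger_kernel(model_type="qwen2", rope=True, swiglu=True, geglu=True)

# Unknown model types are a no-op apart from an informational log message.
_apply_liger_kernel(model_type="not_a_real_model")
```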
@@ -5,12 +5,28 @@ from liger_kernel.ops.rms_norm import LigerRMSNormFunction


  class LigerRMSNorm(nn.Module):
-     def __init__(self, hidden_size, eps=1e-6):
+     def __init__(
+         self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones"
+     ):
          super().__init__()
-         self.weight = nn.Parameter(torch.ones(hidden_size))
-         self.variance_epsilon = eps
+         assert init_fn in [
+             "ones",
+             "zeros",
+         ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
+         self.weight = nn.Parameter(
+             torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
+         )
+         self.variance_epsilon, self.offset, self.casting_mode = (
+             eps,
+             offset,
+             casting_mode,
+         )

      def forward(self, hidden_states):
          return LigerRMSNormFunction.apply(
-             hidden_states, self.weight, self.variance_epsilon
+             hidden_states,
+             self.weight,
+             self.variance_epsilon,
+             self.offset,
+             self.casting_mode,
          )
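The new `offset` and `init_fn` arguments let one kernel cover both parameterizations: Llama-style RMSNorm (ones-initialized weight, no offset) and Gemma-style RMSNorm, which stores the weight as a zero-initialized delta and scales by `1 + weight`. Below is a plain-PyTorch sketch of the intended scaling semantics only; it is not the Triton kernel, and the `casting_mode` dtype handling is left out:

```python
import torch


def rms_norm_reference(x, weight, eps=1e-6, offset=0.0):
    # Normalize by the root mean square, then scale by (offset + weight).
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * (offset + weight)


hidden = torch.randn(2, 4, 8)
llama_style = rms_norm_reference(hidden, torch.ones(8), offset=0.0)   # init_fn="ones"
gemma_style = rms_norm_reference(hidden, torch.zeros(8), offset=1.0)  # init_fn="zeros"
assert torch.allclose(llama_style, gemma_style)  # identical at initialization
```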
@@ -38,3 +38,27 @@ class LigerBlockSparseTop2MLP(nn.Module):
      def forward(self, x):

          return self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x)))
+
+
+ class LigerPhi3SwiGLUMLP(nn.Module):
+     """
+     Patch Phi3MLP to use LigerSiLUMulFunction
+     https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/models/phi3/modeling_phi3.py#L241
+     """
+
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.gate_up_proj = nn.Linear(
+             self.hidden_size, 2 * self.intermediate_size, bias=False
+         )
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+         if config.hidden_act not in ["silu", "swish"]:
+             raise ValueError(f"Activation function {config.hidden_act} not supported.")
+
+     def forward(self, x):
+         up_states = self.gate_up_proj(x)
+         gate, up_states = up_states.chunk(2, dim=-1)
+         return self.down_proj(LigerSiLUMulFunction.apply(gate, up_states))
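Phi3 packs the gate and up projections into a single `gate_up_proj` matrix, so the MLP splits that output in half before the SwiGLU step; `LigerSiLUMulFunction` fuses the `silu(gate) * up` elementwise part into one Triton kernel. An unfused plain-PyTorch sketch of the same computation, for illustration only:

```python
import torch
import torch.nn.functional as F


def phi3_mlp_reference(x, gate_up_weight, down_weight):
    # gate_up_weight: (2 * intermediate, hidden); down_weight: (hidden, intermediate)
    up_states = x @ gate_up_weight.T        # one matmul produces both halves
    gate, up = up_states.chunk(2, dim=-1)   # split into gate and up projections
    fused = F.silu(gate) * up               # the elementwise step Liger fuses
    return fused @ down_weight.T            # project back to the hidden size


x = torch.randn(2, 16, 32)                  # (batch, seq, hidden=32)
out = phi3_mlp_reference(x, torch.randn(128, 32), torch.randn(32, 64))
assert out.shape == (2, 16, 32)
```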
@@ -1,45 +1,2 @@
- import logging
-
- from liger_kernel.transformers.monkey_patch import (
-     apply_liger_kernel_to_gemma,
-     apply_liger_kernel_to_llama,
-     apply_liger_kernel_to_mistral,
-     apply_liger_kernel_to_mixtral,
- )
-
- logger = logging.getLogger(__name__)
-
- # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
- MODEL_TYPE_TO_APPLY_LIGER_FN = {
-     "gemma": apply_liger_kernel_to_gemma,
-     "llama": apply_liger_kernel_to_llama,
-     "mistral": apply_liger_kernel_to_mistral,
-     "mixtral": apply_liger_kernel_to_mixtral,
- }
-
-
- def _apply_liger_kernel(model_type: str = "", **kwargs) -> None:
-     """
-     Applies Liger kernels based on the specified model type. The custom
-     kernels for the specified model type will be applied with the provided
-     keyword arguments, otherwise the default configuration will be used.
-
-     Args:
-     - model_type: the model types as defined in transformers/models/auto/modeling_auto.py
-       and specified in the model's config.json
-     - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
-     """
-
-     if not model_type:
-         logger.info("Model type was not provided. No Liger kernels will be applied.")
-         return
-
-     if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
-         logger.info(
-             f"There are currently no Liger kernels supported for model type: {model_type}."
-         )
-         return
-
-     logger.info(f"Applying Liger kernels for model type: {model_type}.")
-     # Apply the default combination of liger kernels available for the model
-     MODEL_TYPE_TO_APPLY_LIGER_FN[model_type](**kwargs)
+ # To not break HF Trainer integration
+ from liger_kernel.transformers.monkey_patch import _apply_liger_kernel  # noqa: F401
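The module is reduced to a re-export so that callers which still import `_apply_liger_kernel` from `trainer_integration` (for example an existing Hugging Face Trainer hook) keep working while the implementation now lives in `monkey_patch`. A quick check of that aliasing, assuming liger-kernel 0.2.0 is installed:

```python
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel as via_monkey_patch
from liger_kernel.transformers.trainer_integration import _apply_liger_kernel as via_trainer_integration

assert via_monkey_patch is via_trainer_integration  # both names resolve to the same function
```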
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: liger-kernel
- Version: 0.1.1
+ Version: 0.2.0
  Summary: Efficient Triton kernels for LLM Training
  Home-page: https://github.com/linkedin/Liger-Kernel
  License: BSD-2-Clause
@@ -21,22 +21,35 @@ License-File: LICENSE
  License-File: NOTICE
  Requires-Dist: torch>=2.1.2
  Requires-Dist: triton>=2.3.0
- Requires-Dist: transformers>=4.40.1
+ Requires-Dist: transformers>=4.42.0
  Provides-Extra: dev
  Requires-Dist: matplotlib>=3.7.2; extra == "dev"
  Requires-Dist: flake8>=4.0.1.1; extra == "dev"
  Requires-Dist: black>=24.4.2; extra == "dev"
  Requires-Dist: isort>=5.13.2; extra == "dev"
- Requires-Dist: pre-commit>=3.7.1; extra == "dev"
- Requires-Dist: torch-tb-profiler>=0.4.1; extra == "dev"
+ Requires-Dist: pytest>=7.1.2; extra == "dev"
+ Requires-Dist: datasets>=2.19.2; extra == "dev"

  # Liger Kernel: Efficient Triton Kernels for LLM Training

+
+
  [![Downloads](https://static.pepy.tech/badge/liger-kernel)](https://pepy.tech/project/liger-kernel) [![PyPI version](https://badge.fury.io/py/liger-kernel.svg)](https://badge.fury.io/py/liger-kernel) [![PyPI version](https://badge.fury.io/py/liger-kernel-nightly.svg)](https://badge.fury.io/py/liger-kernel-nightly)
+ [![](https://dcbadge.vercel.app/api/server/cudamode?style=flat)](https://discord.gg/CX2YmNmn)

+ <img src="https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/logo-banner.png">

  [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Structure](#structure) | [Contributing](#contributing)

+ <details>
+ <summary>Latest News 🔥</summary>
+
+ - [2024/8/31] CUDA MODE talk, [Liger-Kernel: Real-world Triton kernel for LLM Training](https://discord.gg/6CNeDAjq?event=1273323969788772455)
+ - [2024/8/23] Official release: check out our [X post](https://x.com/hsu_byron/status/1827072737673982056)
+
+ </details>
+
+
  **Liger (Linkedin GPU Efficient Runtime) Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduces **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.

  ## Supercharge Your Model with Liger Kernel
@@ -52,8 +65,8 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
  | ![Speed up](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-tps.png) | ![Memory](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-memory.png) |

  > **Note:**
- > - Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
- > - Hugging Face models start to OOM at a 4K context length, whereas Hugging Face + Liger Kernel scales up to 16K.
+ > - Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
+ > - Hugging Face models start to OOM at a 4K context length, whereas Hugging Face + Liger Kernel scales up to 16K.

  ## Examples

@@ -91,12 +104,15 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and

  - `torch >= 2.1.2`
  - `triton >= 2.3.0`
- - `transformers >= 4.40.1`
+ - `transformers >= 4.42.0`
+
+ > **Note:**
+ > Our kernels inherit the full spectrum of hardware compatibility offered by [Triton](https://github.com/triton-lang/triton).

  To install the stable version:

  ```bash
- $ pip install liger-kernel
+ $ pip install liger-kernel
  ```

  To install the nightly version:
@@ -105,9 +121,30 @@ To install the nightly version:
  $ pip install liger-kernel-nightly
  ```

+ To install from source:
+
+ ```bash
+ git clone https://github.com/linkedin/Liger-Kernel.git
+ cd Liger-Kernel
+ pip install -e .
+ ```
  ## Getting Started

- ### 1. Patch Existing Hugging Face Models
+ There are a couple ways to apply Liger kernels, depending on the level of customization required.
+
+ ### 1. Use AutoLigerKernelForCausalLM
+
+ Using the `AutoLigerKernelForCausalLM` is the simplest approach, as you don't have to import a model-specific patching API. If the model type is supported, the modeling code will be automatically patched using the default settings.
+
+ ```python
+ from liger_kernel.transformers import AutoLigerKernelForCausalLM
+
+ # This AutoModel wrapper class automatically monkey-patches the
+ # model with the optimized Liger kernels if the model is supported.
+ model = AutoLigerKernelForCausalLM.from_pretrained("path/to/some/model")
+ ```
+
+ ### 2. Apply Model-Specific Patching APIs

  Using the [patching APIs](#patching), you can swap Hugging Face models with optimized Liger Kernels.

@@ -115,13 +152,22 @@ Using the [patching APIs](#patching), you can swap Hugging Face models with opti
  import transformers
  from liger_kernel.transformers import apply_liger_kernel_to_llama

- model = transformers.AutoModelForCausalLM.from_pretrained("<some llama model>")
+ model = transformers.AutoModelForCausalLM("path/to/llama/model")

  # Adding this line automatically monkey-patches the model with the optimized Liger kernels
- apply_liger_kernel_to_llama()
+ apply_liger_kernel_to_llama()
+
+ # You could alternatively specify exactly which kernels are applied
+ apply_liger_kernel_to_llama(
+     rope=True,
+     swiglu=True,
+     cross_entropy=True,
+     fused_linear_cross_entropy=False,
+     rms_norm=False
+ )
  ```

- ### 2. Compose Your Own Model
+ ### 3. Compose Your Own Model

  You can take individual [kernels](#kernels) to compose your models.

@@ -161,14 +207,26 @@ loss.backward()

  ## APIs

+ ### AutoModel
+
+ | **AutoModel Variant** | **API** |
+ |-----------|---------|
+ | AutoModelForCausalLM | `liger_kernel.transformers.AutoLigerKernelForCausalLM` |
+
+
  ### Patching

  | **Model** | **API** | **Supported Operations** |
  |-------------|--------------------------------------------------------------|-------------------------------------------------------------------------|
- | LLaMA (2 & 3) | `liger_kernel.transformers.apply_liger_kernel_to_llama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
- | Mistral | `liger_kernel.transformers.apply_liger_kernel_to_mistral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
+ | LLaMA 2 & 3 | `liger_kernel.transformers.apply_liger_kernel_to_llama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Mistral | `liger_kernel.transformers.apply_liger_kernel_to_mistral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
- | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss |
+ | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss |
+ | Qwen2 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Phi3 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+
+

  ### Kernels

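As an illustration of the newly added table entries, the same one-line patching flow from the Getting Started section applies to the new models as well; a sketch for Qwen2, with the checkpoint path as a placeholder:

```python
import transformers
from liger_kernel.transformers import apply_liger_kernel_to_qwen2

# Patch the Qwen2 modeling code before instantiating the model, so the
# swapped-in RMSNorm / SwiGLU / RoPE implementations are picked up.
apply_liger_kernel_to_qwen2()
model = transformers.AutoModelForCausalLM.from_pretrained("path/to/qwen2/model")
```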
@@ -182,11 +240,11 @@ loss.backward()
  | FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|

  - **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
- - **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction.
- - **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
+ - **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction.
+ - **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
  $$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$
  , is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction.
- - **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
+ - **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
  $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
  , is implemented by fusing the elementwise multiplication into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction. Note that the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used.
  - **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with inplace replacement of input to reduce the peak memory by avoiding simultaneous materialization of both input logits and gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.).
@@ -195,12 +253,12 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$


  <!-- TODO: be more specific about batch size -->
- > **Note:**
+ > **Note:**
  > Reported speedups and memory reductions are with respect to the LLaMA 3-8B Hugging Face layer implementations. All models use 4K hidden size and 4K sequence length and are evaluated based on memory usage and wall time for the forward+backward pass on a single NVIDIA A100 80G GPU using small batch sizes. Liger kernels exhibit more efficient scaling to larger batch sizes, detailed further in the [Benchmark](./benchmark) folder.

  ## Note on ML Compiler

- ### 1. Torch Compile
+ ### Torch Compile

  Since Liger Kernel is 100% Triton-based, it works seamlessly with [`torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html). In the following example, Liger Kernel can further optimize the model on top of Torch Compile, reducing the memory by more than half.

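Before the benchmark table in the next hunk, a minimal sketch of combining the two (the model path is a placeholder):

```python
import torch
import transformers
from liger_kernel.transformers import apply_liger_kernel_to_llama

apply_liger_kernel_to_llama()  # apply Liger patches before the model is built and compiled
model = transformers.AutoModelForCausalLM.from_pretrained("path/to/llama/model")
model = torch.compile(model)   # Triton-based kernels compose with torch.compile
```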
@@ -209,20 +267,17 @@ Since Liger Kernel is 100% Triton-based, it works seamlessly with [`torch.compil
  | Torch Compile | 3780 | 66.4 |
  | Torch Compile + Liger Kernel | 3702 | 31.0 |

- > **Note:**
+ > **Note:**
  > 1. Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Seq Len = 4096, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
  > 2. Tested on torch `2.5.0.dev20240731+cu118`

- ### 2. Lightning Thunder
-
- *WIP*
-
  ## Contributing

  [CONTRIBUTING GUIDE](https://github.com/linkedin/Liger-Kernel/blob/main/CONTRIBUTING.md)

  ## Acknowledgement

+ - [@claire_yishan](https://twitter.com/claire_yishan) for the LOGO design
  - [flash-attn](https://github.com/Dao-AILab/flash-attention) and [Unsloth](https://github.com/unslothai/unsloth) for inspiration in Triton kernels for training
  - [tiny shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) by Andrej Karpathy for convergence testing
  - [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy) for lm_head + cross entropy inspiration
@@ -232,6 +287,10 @@ Since Liger Kernel is 100% Triton-based, it works seamlessly with [`torch.compil

  [BSD 2-CLAUSE](https://github.com/linkedin/Liger-Kernel/blob/main/LICENSE)

+ ## Contact
+
+ - For collaboration, please send email to byhsu@linkedin.com
+
  ## Cite this work

  Biblatex entry:
@@ -243,3 +302,6 @@ Biblatex entry:
      year = {2024}
  }
  ```
+
+ ## Star History
+ [![Star History Chart](https://api.star-history.com/svg?repos=linkedin/Liger-Kernel&type=Date)](https://star-history.com/#linkedin/Liger-Kernel&Date)
@@ -0,0 +1,33 @@
+ liger_kernel/env_report.py,sha256=LFUJ6UMkFFGPBYXBlqHFGy4bhsemEpSI-_1edSazlHI,1130
+ liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ liger_kernel/ops/cross_entropy.py,sha256=6-jI03Yw_B8gHLmqxTOMpbFMRJhNNaE4DKpmowWYrTE,9177
+ liger_kernel/ops/fused_linear_cross_entropy.py,sha256=RCStJoBsgiAa03ZupWhZEHjnAbm52iNiMDsGs6VDtnY,8385
+ liger_kernel/ops/geglu.py,sha256=f8j9qnZgqvXFDFStZ5WbtRPDuNo9FBdVdXX7ufMHNpE,4052
+ liger_kernel/ops/rms_norm.py,sha256=B0FDElKiTygv1CdG3NzbeOeC7nj_-2vmNIg4RHistHI,9517
+ liger_kernel/ops/rope.py,sha256=8TOkpjmeekQEp1x6OAXTAWwoTTcEhNHSk9GnjuhW-Cw,8570
+ liger_kernel/ops/swiglu.py,sha256=MRbSIXsBLqlFr9ZdtuFqSjLJJ-716URmQIhxQ57GGEw,2915
+ liger_kernel/ops/utils.py,sha256=vsFIywd8LQlVPRA3RPZOm5HyN8c0cS4NFEEnwjNw-MI,1427
+ liger_kernel/transformers/__init__.py,sha256=Um9ZRvT289MVFoGmliSva3q3YLRDqwYmBLxIj0rD9nI,403
+ liger_kernel/transformers/auto_model.py,sha256=WQyaORi2zPIWTLhuAWCRPIzyHd5T4my4yGHQrt1-uBA,1247
+ liger_kernel/transformers/cross_entropy.py,sha256=G-L4EaUYVc25NKZ2jrlaG-d5YUvDqJdUlawPN7K1d1g,389
+ liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=X6ni--b5F2GRxL46PrgjsvQuWEcKp3Z2cELdMRR0oyY,518
+ liger_kernel/transformers/geglu.py,sha256=QcrME_8ooIn0xa59LaC0aoOdRrBIFd11Y0bAyF0NfCw,1130
+ liger_kernel/transformers/monkey_patch.py,sha256=vbYmT7povRJQA49Ra7GGD3hMN2lVTXmsxO9R8lfQSoI,13022
+ liger_kernel/transformers/rms_norm.py,sha256=YxBSn2bIfh24De8Xb7QhhmdG3taauj_qJNvEjlazonU,912
+ liger_kernel/transformers/rope.py,sha256=m-ah8vZBYW8tfplTXCiAPMHJWlB1tdp_JPXJeWE-Boo,943
+ liger_kernel/transformers/swiglu.py,sha256=0-tVJ8xEYfhxnduc16PflXFj8sZPxdx9sHUn3hfwCI4,2468
+ liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
+ liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ liger_kernel/transformers/model/gemma.py,sha256=EcdkGbSj_qroTDFl0Sc_HLyDyY0xcDhwrgkM_wkXnw8,4987
+ liger_kernel/transformers/model/llama.py,sha256=6McXLi_Bt35WuxaJ_0CzEnOtayHXiPw5vjiDsaQKdJU,5323
+ liger_kernel/transformers/model/mistral.py,sha256=_MQJrDntlxBO5cJwgTjr2rk2nNd5FAXVnzcTg_PEekQ,5079
+ liger_kernel/transformers/model/phi3.py,sha256=zmjOsVV5TjKJ0U2dCm6W-8WCx1toKoh2Wm2PZu3XOIw,4927
+ liger_kernel/transformers/model/qwen2.py,sha256=Va4uiZaVzCG2V7XKDfHjZyYTre5vPQM02j83jnnhono,4873
+ liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
+ liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
+ liger_kernel-0.2.0.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+ liger_kernel-0.2.0.dist-info/METADATA,sha256=eAzgrCGOn_jpoiH8VomABxzt7CThKfBqOIKH5Qmfm3w,17049
+ liger_kernel-0.2.0.dist-info/NOTICE,sha256=BXkXY9aWvEy_7MAB57zDu1z8uMYT1i1l9B6EpHuBa8s,173
+ liger_kernel-0.2.0.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ liger_kernel-0.2.0.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+ liger_kernel-0.2.0.dist-info/RECORD,,