liger-kernel 0.1.1__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/PKG-INFO +84 -22
  2. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/README.md +83 -21
  3. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/setup.py +15 -15
  4. liger_kernel-0.2.1/src/liger_kernel/env_report.py +46 -0
  5. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/src/liger_kernel/ops/cross_entropy.py +5 -5
  6. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/src/liger_kernel/ops/fused_linear_cross_entropy.py +50 -21
  7. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/src/liger_kernel/ops/geglu.py +6 -1
  8. liger_kernel-0.2.1/src/liger_kernel/ops/rms_norm.py +307 -0
  9. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/src/liger_kernel/ops/rope.py +3 -3
  10. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/src/liger_kernel/transformers/__init__.py +6 -0
  11. liger_kernel-0.2.1/src/liger_kernel/transformers/auto_model.py +33 -0
  12. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +2 -2
  13. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/src/liger_kernel/transformers/geglu.py +4 -2
  14. liger_kernel-0.2.1/src/liger_kernel/transformers/model/gemma.py +138 -0
  15. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/src/liger_kernel/transformers/model/llama.py +1 -1
  16. liger_kernel-0.2.1/src/liger_kernel/transformers/model/mistral.py +138 -0
  17. liger_kernel-0.2.1/src/liger_kernel/transformers/model/phi3.py +136 -0
  18. liger_kernel-0.2.1/src/liger_kernel/transformers/model/qwen2.py +135 -0
  19. liger_kernel-0.2.1/src/liger_kernel/transformers/monkey_patch.py +323 -0
  20. liger_kernel-0.2.1/src/liger_kernel/transformers/rms_norm.py +32 -0
  21. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/src/liger_kernel/transformers/swiglu.py +24 -0
  22. liger_kernel-0.2.1/src/liger_kernel/transformers/trainer_integration.py +2 -0
  23. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/src/liger_kernel.egg-info/PKG-INFO +84 -22
  24. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/src/liger_kernel.egg-info/SOURCES.txt +6 -0
  25. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/src/liger_kernel.egg-info/requires.txt +3 -3
  26. liger_kernel-0.1.1/src/liger_kernel/ops/rms_norm.py +0 -185
  27. liger_kernel-0.1.1/src/liger_kernel/transformers/monkey_patch.py +0 -130
  28. liger_kernel-0.1.1/src/liger_kernel/transformers/rms_norm.py +0 -16
  29. liger_kernel-0.1.1/src/liger_kernel/transformers/trainer_integration.py +0 -45
  30. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/LICENSE +0 -0
  31. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/NOTICE +0 -0
  32. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/setup.cfg +0 -0
  33. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/src/liger_kernel/ops/__init__.py +0 -0
  34. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/src/liger_kernel/ops/swiglu.py +0 -0
  35. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/src/liger_kernel/ops/utils.py +0 -0
  36. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  37. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/src/liger_kernel/transformers/model/__init__.py +0 -0
  38. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/src/liger_kernel/transformers/rope.py +0 -0
  39. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/src/liger_kernel/triton/__init__.py +0 -0
  40. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/src/liger_kernel/triton/monkey_patch.py +0 -0
  41. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/src/liger_kernel.egg-info/dependency_links.txt +0 -0
  42. {liger_kernel-0.1.1 → liger_kernel-0.2.1}/src/liger_kernel.egg-info/top_level.txt +0 -0
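Several of the files added in 0.2.1 (`auto_model.py`, `monkey_patch.py`, `rms_norm.py`, and the per-model files under `transformers/model/`) appear in the list above but only `env_report.py` is shown in full below. For orientation only, here is a hypothetical sketch of how a wrapper like the new `AutoLigerKernelForCausalLM` could be structured, dispatching on the Hugging Face `config.model_type` to the `apply_liger_kernel_to_*` patching functions listed in the README diff; the actual `auto_model.py` (+33 lines) is not shown in this diff and may differ:

```python
# Hypothetical sketch only -- the real auto_model.py body is not included
# in this diff and may be implemented differently.
from transformers import AutoConfig, AutoModelForCausalLM

from liger_kernel.transformers import (
    apply_liger_kernel_to_llama,
    apply_liger_kernel_to_mistral,
)

# model_type -> patching function (illustrative subset of the Patching table)
_MODEL_TYPE_TO_PATCH_FN = {
    "llama": apply_liger_kernel_to_llama,
    "mistral": apply_liger_kernel_to_mistral,
}


class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
    """Loads a causal LM, monkey-patching its modeling code with Liger
    kernels first when the model type is supported."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
        patch_fn = _MODEL_TYPE_TO_PATCH_FN.get(getattr(config, "model_type", None))
        if patch_fn is not None:
            patch_fn()  # patch before the model classes are instantiated
        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
```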
{liger_kernel-0.1.1 → liger_kernel-0.2.1}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: liger_kernel
- Version: 0.1.1
+ Version: 0.2.1
  Summary: Efficient Triton kernels for LLM Training
  Home-page: https://github.com/linkedin/Liger-Kernel
  License: BSD-2-Clause
@@ -23,11 +23,24 @@ License-File: NOTICE

  # Liger Kernel: Efficient Triton Kernels for LLM Training

+
+
  [![Downloads](https://static.pepy.tech/badge/liger-kernel)](https://pepy.tech/project/liger-kernel) [![PyPI version](https://badge.fury.io/py/liger-kernel.svg)](https://badge.fury.io/py/liger-kernel) [![PyPI version](https://badge.fury.io/py/liger-kernel-nightly.svg)](https://badge.fury.io/py/liger-kernel-nightly)
+ [![](https://dcbadge.vercel.app/api/server/cudamode?style=flat)](https://discord.gg/CX2YmNmn)

+ <img src="https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/logo-banner.png">

  [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Structure](#structure) | [Contributing](#contributing)

+ <details>
+ <summary>Latest News 🔥</summary>
+
+ - [2024/8/31] CUDA MODE talk, [Liger-Kernel: Real-world Triton kernel for LLM Training](https://discord.gg/6CNeDAjq?event=1273323969788772455)
+ - [2024/8/23] Official release: check out our [X post](https://x.com/hsu_byron/status/1827072737673982056)
+
+ </details>
+
+
  **Liger (Linkedin GPU Efficient Runtime) Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduces **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.

  ## Supercharge Your Model with Liger Kernel
@@ -43,8 +56,8 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
  | ![Speed up](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-tps.png) | ![Memory](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-memory.png) |

  > **Note:**
- > - Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
- > - Hugging Face models start to OOM at a 4K context length, whereas Hugging Face + Liger Kernel scales up to 16K.
+ > - Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
+ > - Hugging Face models start to OOM at a 4K context length, whereas Hugging Face + Liger Kernel scales up to 16K.

  ## Examples

@@ -82,12 +95,15 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and

  - `torch >= 2.1.2`
  - `triton >= 2.3.0`
- - `transformers >= 4.40.1`
+ - `transformers >= 4.42.0`
+
+ > **Note:**
+ > Our kernels inherit the full spectrum of hardware compatibility offered by [Triton](https://github.com/triton-lang/triton).

  To install the stable version:

  ```bash
- $ pip install liger-kernel
+ $ pip install liger-kernel
  ```

  To install the nightly version:
@@ -96,9 +112,30 @@ To install the nightly version:
  $ pip install liger-kernel-nightly
  ```

+ To install from source:
+
+ ```bash
+ git clone https://github.com/linkedin/Liger-Kernel.git
+ cd Liger-Kernel
+ pip install -e .
+ ```
  ## Getting Started

- ### 1. Patch Existing Hugging Face Models
+ There are a couple ways to apply Liger kernels, depending on the level of customization required.
+
+ ### 1. Use AutoLigerKernelForCausalLM
+
+ Using the `AutoLigerKernelForCausalLM` is the simplest approach, as you don't have to import a model-specific patching API. If the model type is supported, the modeling code will be automatically patched using the default settings.
+
+ ```python
+ from liger_kernel.transformers import AutoLigerKernelForCausalLM
+
+ # This AutoModel wrapper class automatically monkey-patches the
+ # model with the optimized Liger kernels if the model is supported.
+ model = AutoLigerKernelForCausalLM.from_pretrained("path/to/some/model")
+ ```
+
+ ### 2. Apply Model-Specific Patching APIs

  Using the [patching APIs](#patching), you can swap Hugging Face models with optimized Liger Kernels.

@@ -106,13 +143,22 @@ Using the [patching APIs](#patching), you can swap Hugging Face models with opti
  import transformers
  from liger_kernel.transformers import apply_liger_kernel_to_llama

- model = transformers.AutoModelForCausalLM.from_pretrained("<some llama model>")
+ model = transformers.AutoModelForCausalLM("path/to/llama/model")

  # Adding this line automatically monkey-patches the model with the optimized Liger kernels
- apply_liger_kernel_to_llama()
+ apply_liger_kernel_to_llama()
+
+ # You could alternatively specify exactly which kernels are applied
+ apply_liger_kernel_to_llama(
+   rope=True,
+   swiglu=True,
+   cross_entropy=True,
+   fused_linear_cross_entropy=False,
+   rms_norm=False
+ )
  ```

- ### 2. Compose Your Own Model
+ ### 3. Compose Your Own Model

  You can take individual [kernels](#kernels) to compose your models.

@@ -152,14 +198,26 @@ loss.backward()

  ## APIs

+ ### AutoModel
+
+ | **AutoModel Variant** | **API** |
+ |-----------|---------|
+ | AutoModelForCausalLM | `liger_kernel.transformers.AutoLigerKernelForCausalLM` |
+
+
  ### Patching

  | **Model** | **API** | **Supported Operations** |
  |-------------|--------------------------------------------------------------|-------------------------------------------------------------------------|
- | LLaMA (2 & 3) | `liger_kernel.transformers.apply_liger_kernel_to_llama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
- | Mistral | `liger_kernel.transformers.apply_liger_kernel_to_mistral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
+ | LLaMA 2 & 3 | `liger_kernel.transformers.apply_liger_kernel_to_llama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Mistral | `liger_kernel.transformers.apply_liger_kernel_to_mistral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
- | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss |
+ | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss |
+ | Qwen2 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Phi3 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+
+

  ### Kernels

@@ -173,11 +231,11 @@ loss.backward()
  | FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|

  - **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
- - **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction.
- - **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
+ - **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction.
+ - **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
  $$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$
  , is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction.
- - **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
+ - **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
  $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
  , is implemented by fusing the elementwise multiplication into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction. Note that the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used.
  - **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with inplace replacement of input to reduce the peak memory by avoiding simultaneous materialization of both input logits and gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.).
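For intuition, here is a minimal, unfused PyTorch rendering of the fused forms described in the bullets above; the `*_reference` names are illustrative, not Liger APIs, and SwiGLU is shown with beta = 1 (i.e. Swish = SiLU). The Liger versions fuse these operations into single Triton kernels with in-place replacement:

```python
import torch
import torch.nn.functional as F

def rmsnorm_reference(x, weight, eps=1e-6):
    # normalize by the root mean square of the last dimension, then scale
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

def swiglu_reference(x, W, b, V, c):
    # Swish with beta = 1 is SiLU, the form used in LLaMA-style MLPs
    return F.silu(x @ W + b) * (x @ V + c)

def geglu_reference(x, W, b, V, c):
    # tanh-approximated GELU, matching the note above
    return F.gelu(x @ W + b, approximate="tanh") * (x @ V + c)
```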
@@ -186,12 +244,12 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$


  <!-- TODO: be more specific about batch size -->
- > **Note:**
+ > **Note:**
  > Reported speedups and memory reductions are with respect to the LLaMA 3-8B Hugging Face layer implementations. All models use 4K hidden size and 4K sequence length and are evaluated based on memory usage and wall time for the forward+backward pass on a single NVIDIA A100 80G GPU using small batch sizes. Liger kernels exhibit more efficient scaling to larger batch sizes, detailed further in the [Benchmark](./benchmark) folder.

  ## Note on ML Compiler

- ### 1. Torch Compile
+ ### Torch Compile

  Since Liger Kernel is 100% Triton-based, it works seamlessly with [`torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html). In the following example, Liger Kernel can further optimize the model on top of Torch Compile, reducing the memory by more than half.

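For concreteness, a minimal sketch of what the "Torch Compile + Liger Kernel" combination looks like in practice, assuming the usual patch-then-load-then-compile order (the model path is a placeholder):

```python
import torch
import transformers
from liger_kernel.transformers import apply_liger_kernel_to_llama

apply_liger_kernel_to_llama()  # patch the modeling code first
model = transformers.AutoModelForCausalLM.from_pretrained("path/to/llama/model")
model = torch.compile(model)   # then compile as usual
```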
@@ -200,20 +258,17 @@ Since Liger Kernel is 100% Triton-based, it works seamlessly with [`torch.compil
  | Torch Compile | 3780 | 66.4 |
  | Torch Compile + Liger Kernel | 3702 | 31.0 |

- > **Note:**
+ > **Note:**
  > 1. Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Seq Len = 4096, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
  > 2. Tested on torch `2.5.0.dev20240731+cu118`

- ### 2. Lightning Thunder
-
- *WIP*
-
  ## Contributing

  [CONTRIBUTING GUIDE](https://github.com/linkedin/Liger-Kernel/blob/main/CONTRIBUTING.md)

  ## Acknowledgement

+ - [@claire_yishan](https://twitter.com/claire_yishan) for the LOGO design
  - [flash-attn](https://github.com/Dao-AILab/flash-attention) and [Unsloth](https://github.com/unslothai/unsloth) for inspiration in Triton kernels for training
  - [tiny shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) by Andrej Karpathy for convergence testing
  - [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy) for lm_head + cross entropy inspiration
@@ -223,6 +278,10 @@ Since Liger Kernel is 100% Triton-based, it works seamlessly with [`torch.compil

  [BSD 2-CLAUSE](https://github.com/linkedin/Liger-Kernel/blob/main/LICENSE)

+ ## Contact
+
+ - For collaboration, please send email to byhsu@linkedin.com
+

  ## Cite this work
  Biblatex entry:
@@ -234,3 +293,6 @@ Biblatex entry:
  year = {2024}
  }
  ```
+
+ ## Star History
+ [![Star History Chart](https://api.star-history.com/svg?repos=linkedin/Liger-Kernel&type=Date)](https://star-history.com/#linkedin/Liger-Kernel&Date)
{liger_kernel-0.1.1 → liger_kernel-0.2.1}/README.md
@@ -1,10 +1,23 @@
  # Liger Kernel: Efficient Triton Kernels for LLM Training

+
+
  [![Downloads](https://static.pepy.tech/badge/liger-kernel)](https://pepy.tech/project/liger-kernel) [![PyPI version](https://badge.fury.io/py/liger-kernel.svg)](https://badge.fury.io/py/liger-kernel) [![PyPI version](https://badge.fury.io/py/liger-kernel-nightly.svg)](https://badge.fury.io/py/liger-kernel-nightly)
+ [![](https://dcbadge.vercel.app/api/server/cudamode?style=flat)](https://discord.gg/CX2YmNmn)

+ <img src="https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/logo-banner.png">

  [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Structure](#structure) | [Contributing](#contributing)

+ <details>
+ <summary>Latest News 🔥</summary>
+
+ - [2024/8/31] CUDA MODE talk, [Liger-Kernel: Real-world Triton kernel for LLM Training](https://discord.gg/6CNeDAjq?event=1273323969788772455)
+ - [2024/8/23] Official release: check out our [X post](https://x.com/hsu_byron/status/1827072737673982056)
+
+ </details>
+
+
  **Liger (Linkedin GPU Efficient Runtime) Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduces **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.

  ## Supercharge Your Model with Liger Kernel
@@ -20,8 +33,8 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
  | ![Speed up](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-tps.png) | ![Memory](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-memory.png) |

  > **Note:**
- > - Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
- > - Hugging Face models start to OOM at a 4K context length, whereas Hugging Face + Liger Kernel scales up to 16K.
+ > - Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
+ > - Hugging Face models start to OOM at a 4K context length, whereas Hugging Face + Liger Kernel scales up to 16K.

  ## Examples

@@ -59,12 +72,15 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and

  - `torch >= 2.1.2`
  - `triton >= 2.3.0`
- - `transformers >= 4.40.1`
+ - `transformers >= 4.42.0`
+
+ > **Note:**
+ > Our kernels inherit the full spectrum of hardware compatibility offered by [Triton](https://github.com/triton-lang/triton).

  To install the stable version:

  ```bash
- $ pip install liger-kernel
+ $ pip install liger-kernel
  ```

  To install the nightly version:
@@ -73,9 +89,30 @@ To install the nightly version:
  $ pip install liger-kernel-nightly
  ```

+ To install from source:
+
+ ```bash
+ git clone https://github.com/linkedin/Liger-Kernel.git
+ cd Liger-Kernel
+ pip install -e .
+ ```
  ## Getting Started

- ### 1. Patch Existing Hugging Face Models
+ There are a couple ways to apply Liger kernels, depending on the level of customization required.
+
+ ### 1. Use AutoLigerKernelForCausalLM
+
+ Using the `AutoLigerKernelForCausalLM` is the simplest approach, as you don't have to import a model-specific patching API. If the model type is supported, the modeling code will be automatically patched using the default settings.
+
+ ```python
+ from liger_kernel.transformers import AutoLigerKernelForCausalLM
+
+ # This AutoModel wrapper class automatically monkey-patches the
+ # model with the optimized Liger kernels if the model is supported.
+ model = AutoLigerKernelForCausalLM.from_pretrained("path/to/some/model")
+ ```
+
+ ### 2. Apply Model-Specific Patching APIs

  Using the [patching APIs](#patching), you can swap Hugging Face models with optimized Liger Kernels.

@@ -83,13 +120,22 @@ Using the [patching APIs](#patching), you can swap Hugging Face models with opti
  import transformers
  from liger_kernel.transformers import apply_liger_kernel_to_llama

- model = transformers.AutoModelForCausalLM.from_pretrained("<some llama model>")
+ model = transformers.AutoModelForCausalLM("path/to/llama/model")

  # Adding this line automatically monkey-patches the model with the optimized Liger kernels
- apply_liger_kernel_to_llama()
+ apply_liger_kernel_to_llama()
+
+ # You could alternatively specify exactly which kernels are applied
+ apply_liger_kernel_to_llama(
+   rope=True,
+   swiglu=True,
+   cross_entropy=True,
+   fused_linear_cross_entropy=False,
+   rms_norm=False
+ )
  ```

- ### 2. Compose Your Own Model
+ ### 3. Compose Your Own Model

  You can take individual [kernels](#kernels) to compose your models.

@@ -129,14 +175,26 @@ loss.backward()

  ## APIs

+ ### AutoModel
+
+ | **AutoModel Variant** | **API** |
+ |-----------|---------|
+ | AutoModelForCausalLM | `liger_kernel.transformers.AutoLigerKernelForCausalLM` |
+
+
  ### Patching

  | **Model** | **API** | **Supported Operations** |
  |-------------|--------------------------------------------------------------|-------------------------------------------------------------------------|
- | LLaMA (2 & 3) | `liger_kernel.transformers.apply_liger_kernel_to_llama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
- | Mistral | `liger_kernel.transformers.apply_liger_kernel_to_mistral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
+ | LLaMA 2 & 3 | `liger_kernel.transformers.apply_liger_kernel_to_llama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Mistral | `liger_kernel.transformers.apply_liger_kernel_to_mistral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
- | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss |
+ | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss |
+ | Qwen2 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Phi3 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+
+

  ### Kernels

@@ -150,11 +208,11 @@ loss.backward()
  | FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|

  - **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
- - **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction.
- - **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
+ - **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction.
+ - **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
  $$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$
  , is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction.
- - **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
+ - **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
  $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
  , is implemented by fusing the elementwise multiplication into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction. Note that the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used.
  - **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with inplace replacement of input to reduce the peak memory by avoiding simultaneous materialization of both input logits and gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.).
@@ -163,12 +221,12 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$


  <!-- TODO: be more specific about batch size -->
- > **Note:**
+ > **Note:**
  > Reported speedups and memory reductions are with respect to the LLaMA 3-8B Hugging Face layer implementations. All models use 4K hidden size and 4K sequence length and are evaluated based on memory usage and wall time for the forward+backward pass on a single NVIDIA A100 80G GPU using small batch sizes. Liger kernels exhibit more efficient scaling to larger batch sizes, detailed further in the [Benchmark](./benchmark) folder.

  ## Note on ML Compiler

- ### 1. Torch Compile
+ ### Torch Compile

  Since Liger Kernel is 100% Triton-based, it works seamlessly with [`torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html). In the following example, Liger Kernel can further optimize the model on top of Torch Compile, reducing the memory by more than half.

@@ -177,20 +235,17 @@ Since Liger Kernel is 100% Triton-based, it works seamlessly with [`torch.compil
  | Torch Compile | 3780 | 66.4 |
  | Torch Compile + Liger Kernel | 3702 | 31.0 |

- > **Note:**
+ > **Note:**
  > 1. Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Seq Len = 4096, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
  > 2. Tested on torch `2.5.0.dev20240731+cu118`

- ### 2. Lightning Thunder
-
- *WIP*
-
  ## Contributing

  [CONTRIBUTING GUIDE](https://github.com/linkedin/Liger-Kernel/blob/main/CONTRIBUTING.md)

  ## Acknowledgement

+ - [@claire_yishan](https://twitter.com/claire_yishan) for the LOGO design
  - [flash-attn](https://github.com/Dao-AILab/flash-attention) and [Unsloth](https://github.com/unslothai/unsloth) for inspiration in Triton kernels for training
  - [tiny shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) by Andrej Karpathy for convergence testing
  - [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy) for lm_head + cross entropy inspiration
@@ -200,6 +255,10 @@ Since Liger Kernel is 100% Triton-based, it works seamlessly with [`torch.compil

  [BSD 2-CLAUSE](https://github.com/linkedin/Liger-Kernel/blob/main/LICENSE)

+ ## Contact
+
+ - For collaboration, please send email to byhsu@linkedin.com
+

  ## Cite this work
  Biblatex entry:
@@ -211,3 +270,6 @@ Biblatex entry:
  year = {2024}
  }
  ```
+
+ ## Star History
+ [![Star History Chart](https://api.star-history.com/svg?repos=linkedin/Liger-Kernel&type=Date)](https://star-history.com/#linkedin/Liger-Kernel&Date)
{liger_kernel-0.1.1 → liger_kernel-0.2.1}/setup.py
@@ -1,6 +1,6 @@
  from setuptools import find_namespace_packages, setup

- __version__ = "0.1.1"
+ __version__ = "0.2.1"

  setup(
      name="liger_kernel",
@@ -13,24 +13,24 @@ setup(
      package_dir={"": "src"},
      packages=find_namespace_packages(where="src"),
      classifiers=[
-         'Development Status :: 4 - Beta',
-         'Intended Audience :: Developers',
-         'Intended Audience :: Science/Research',
-         'Intended Audience :: Education',
-         'License :: OSI Approved :: BSD License',
-         'Programming Language :: Python :: 3',
-         'Programming Language :: Python :: 3.8',
-         'Programming Language :: Python :: 3.9',
-         'Programming Language :: Python :: 3.10',
-         'Topic :: Software Development :: Libraries',
-         'Topic :: Scientific/Engineering :: Artificial Intelligence',
+         "Development Status :: 4 - Beta",
+         "Intended Audience :: Developers",
+         "Intended Audience :: Science/Research",
+         "Intended Audience :: Education",
+         "License :: OSI Approved :: BSD License",
+         "Programming Language :: Python :: 3",
+         "Programming Language :: Python :: 3.8",
+         "Programming Language :: Python :: 3.9",
+         "Programming Language :: Python :: 3.10",
+         "Topic :: Software Development :: Libraries",
+         "Topic :: Scientific/Engineering :: Artificial Intelligence",
      ],
      keywords="triton,kernels,LLM training,deep learning,Hugging Face,PyTorch,GPU optimization",
      include_package_data=True,
      install_requires=[
          "torch>=2.1.2",
          "triton>=2.3.0",
-         "transformers>=4.40.1",
+         "transformers>=4.42.0",
      ],
      extras_require={
          "dev": [
@@ -38,8 +38,8 @@ setup(
              "flake8>=4.0.1.1",
              "black>=24.4.2",
              "isort>=5.13.2",
-             "pre-commit>=3.7.1",
-             "torch-tb-profiler>=0.4.1",
+             "pytest>=7.1.2",
+             "datasets>=2.19.2",
          ]
      },
  )
liger_kernel-0.2.1/src/liger_kernel/env_report.py (new file)
@@ -0,0 +1,46 @@
+ import platform
+ import sys
+
+
+ def print_env_report():
+     """
+     Prints a report of the environment. Useful for debugging and reproducibility.
+     Usage:
+     ```
+     python -m liger_kernel.env_report
+     ```
+     """
+     print("Environment Report:")
+     print("-------------------")
+     print(f"Operating System: {platform.platform()}")
+     print(f"Python version: {sys.version.split()[0]}")
+
+     try:
+         import torch
+
+         print(f"PyTorch version: {torch.__version__}")
+         cuda_version = (
+             torch.version.cuda if torch.cuda.is_available() else "Not available"
+         )
+         print(f"CUDA version: {cuda_version}")
+     except ImportError:
+         print("PyTorch: Not installed")
+         print("CUDA version: Unable to query")
+
+     try:
+         import triton
+
+         print(f"Triton version: {triton.__version__}")
+     except ImportError:
+         print("Triton: Not installed")
+
+     try:
+         import transformers
+
+         print(f"Transformers version: {transformers.__version__}")
+     except ImportError:
+         print("Transformers: Not installed")
+
+
+ if __name__ == "__main__":
+     print_env_report()
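As the docstring notes, the report can be produced from the command line via `python -m liger_kernel.env_report`, or equivalently by importing the function directly:

```python
# Equivalent to: python -m liger_kernel.env_report
# (the module form relies on the `if __name__ == "__main__"` guard above)
from liger_kernel.env_report import print_env_report

print_env_report()
```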
{liger_kernel-0.1.1 → liger_kernel-0.2.1}/src/liger_kernel/ops/cross_entropy.py
@@ -56,7 +56,7 @@ def liger_cross_entropy_kernel(
      # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
      # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867

-     # 3. [Oneline softmax] first pass: find max + sum
+     # 3. [Online softmax] first pass: find max + sum
      m = float("-inf")  # m is the max value. use the notation from the paper
      d = 0.0  # d is the sum. use the notation from the paper
      ori_X_y = tl.load(
@@ -73,10 +73,10 @@ def liger_cross_entropy_kernel(
          d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
          m = m_new

-     # 4. [Oneline softmax] second pass: calculate the gradients
+     # 4. [Online softmax] second pass: calculate the gradients
      # dx_y = (softmax(x_y) - 1) / N
      # dx_i = softmax(x_i) / N, i != y
-     # N is the number of non ingored elements in the batch
+     # N is the number of non ignored elements in the batch
      for i in range(0, n_cols, BLOCK_SIZE):
          X_offsets = i + tl.arange(0, BLOCK_SIZE)
          X_block = tl.load(
@@ -86,7 +86,7 @@ def liger_cross_entropy_kernel(
          tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

      # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
-     # ttps://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
+     # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
      tl.debug_barrier()

      # 5. Calculate the loss
@@ -196,7 +196,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
          ignore_index=ignore_index,
          BLOCK_SIZE=BLOCK_SIZE,
          # TODO: 32 seems to give the best performance
-         # Performance is quite sentitive to num_warps
+         # Performance is quite sensitive to num_warps
          num_warps=32,
      )
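The "first pass: find max + sum" that these comments describe is the online softmax recurrence (Algorithm 3 of https://arxiv.org/pdf/1805.02867): a single pass maintains the running max `m` and normalizer `d`, rescaling `d` whenever the max grows. A pure-Python sketch of the same recurrence, per element rather than per `BLOCK_SIZE` block, for illustration:

```python
import math

def online_softmax(x):
    # One pass keeps a running max m and a running normalizer d, rescaling d
    # whenever the max grows -- the same d = d * exp(m - m_new) + exp(x_i - m_new)
    # update used in the kernel above, with an effective block size of 1.
    m, d = float("-inf"), 0.0
    for xi in x:
        m_new = max(m, xi)
        d = d * math.exp(m - m_new) + math.exp(xi - m_new)
        m = m_new
    return [math.exp(xi - m) / d for xi in x]

print(online_softmax([0.5, 2.0, -1.0, 3.0]))  # probabilities summing to 1.0
```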