liger-kernel 0.4.2__tar.gz → 0.5.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {liger_kernel-0.4.2/src/liger_kernel.egg-info → liger_kernel-0.5.1}/PKG-INFO +69 -45
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/README.md +59 -44
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/pyproject.toml +17 -1
- liger_kernel-0.5.1/src/liger_kernel/chunked_loss/__init__.py +4 -0
- liger_kernel-0.5.1/src/liger_kernel/chunked_loss/cpo_loss.py +107 -0
- liger_kernel-0.5.1/src/liger_kernel/chunked_loss/dpo_loss.py +135 -0
- liger_kernel-0.5.1/src/liger_kernel/chunked_loss/functional.py +9 -0
- liger_kernel-0.5.1/src/liger_kernel/chunked_loss/fused_linear_distillation.py +252 -0
- liger_kernel-0.5.1/src/liger_kernel/chunked_loss/fused_linear_preference.py +386 -0
- liger_kernel-0.5.1/src/liger_kernel/chunked_loss/orpo_loss.py +113 -0
- liger_kernel-0.5.1/src/liger_kernel/chunked_loss/simpo_loss.py +115 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/env_report.py +22 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/cross_entropy.py +17 -10
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -11
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/fused_linear_jsd.py +1 -1
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/jsd.py +19 -10
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/layer_norm.py +6 -1
- liger_kernel-0.5.1/src/liger_kernel/ops/qwen2vl_mrope.py +238 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/rms_norm.py +6 -1
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/utils.py +5 -2
- liger_kernel-0.5.1/src/liger_kernel/transformers/functional.py +173 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/fused_linear_jsd.py +1 -4
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/jsd.py +1 -4
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/monkey_patch.py +6 -4
- liger_kernel-0.5.1/src/liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel-0.5.1/src/liger_kernel/transformers/trainer/__init__.py +6 -0
- liger_kernel-0.5.1/src/liger_kernel/transformers/trainer/orpo_trainer.py +169 -0
- liger_kernel-0.5.1/src/liger_kernel/utils.py +13 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1/src/liger_kernel.egg-info}/PKG-INFO +69 -45
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel.egg-info/SOURCES.txt +10 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel.egg-info/requires.txt +11 -0
- liger_kernel-0.4.2/src/liger_kernel/chunked_loss/dpo_loss.py +0 -57
- liger_kernel-0.4.2/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -206
- liger_kernel-0.4.2/src/liger_kernel/chunked_loss/orpo_loss.py +0 -63
- liger_kernel-0.4.2/src/liger_kernel/transformers/functional.py +0 -56
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/LICENSE +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/NOTICE +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/setup.cfg +0 -0
- {liger_kernel-0.4.2/src/liger_kernel/chunked_loss → liger_kernel-0.5.1/src/liger_kernel}/__init__.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/__init__.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel.egg-info/dependency_links.txt +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: liger_kernel
|
|
3
|
-
Version: 0.4.2
|
|
3
|
+
Version: 0.5.1
|
|
4
4
|
Summary: Efficient Triton kernels for LLM Training
|
|
5
5
|
License: BSD 2-CLAUSE LICENSE
|
|
6
6
|
Copyright 2024 LinkedIn Corporation
|
|
@@ -34,6 +34,8 @@ Requires-Dist: torch>=2.1.2
|
|
|
34
34
|
Requires-Dist: triton>=2.3.1
|
|
35
35
|
Provides-Extra: transformers
|
|
36
36
|
Requires-Dist: transformers~=4.0; extra == "transformers"
|
|
37
|
+
Provides-Extra: trl
|
|
38
|
+
Requires-Dist: trl>=0.11.0; extra == "trl"
|
|
37
39
|
Provides-Extra: dev
|
|
38
40
|
Requires-Dist: transformers>=4.44.2; extra == "dev"
|
|
39
41
|
Requires-Dist: matplotlib>=3.7.2; extra == "dev"
|
|
@@ -41,9 +43,16 @@ Requires-Dist: flake8>=4.0.1.1; extra == "dev"
|
|
|
41
43
|
Requires-Dist: black>=24.4.2; extra == "dev"
|
|
42
44
|
Requires-Dist: isort>=5.13.2; extra == "dev"
|
|
43
45
|
Requires-Dist: pytest>=7.1.2; extra == "dev"
|
|
46
|
+
Requires-Dist: pytest-xdist; extra == "dev"
|
|
47
|
+
Requires-Dist: pytest-rerunfailures; extra == "dev"
|
|
44
48
|
Requires-Dist: datasets>=2.19.2; extra == "dev"
|
|
45
49
|
Requires-Dist: torchvision>=0.16.2; extra == "dev"
|
|
46
50
|
Requires-Dist: seaborn; extra == "dev"
|
|
51
|
+
Provides-Extra: amd
|
|
52
|
+
Requires-Dist: torch>=2.6.0.dev; extra == "amd"
|
|
53
|
+
Requires-Dist: setuptools-scm>=8; extra == "amd"
|
|
54
|
+
Requires-Dist: torchvision>=0.20.0.dev; extra == "amd"
|
|
55
|
+
Requires-Dist: triton>=3.0.0; extra == "amd"
|
|
47
56
|
|
|
48
57
|
<a name="readme-top"></a>
|
|
49
58
|
|
|
@@ -55,7 +64,7 @@ Requires-Dist: seaborn; extra == "dev"
|
|
|
55
64
|
<th style="padding: 10px;" colspan="2">Stable</th>
|
|
56
65
|
<th style="padding: 10px;" colspan="2">Nightly</th>
|
|
57
66
|
<th style="padding: 10px;">Discord</th>
|
|
58
|
-
<th style="padding: 10px;">
|
|
67
|
+
<th style="padding: 10px;">Build</th>
|
|
59
68
|
</tr>
|
|
60
69
|
<tr>
|
|
61
70
|
<td style="padding: 10px;">
|
|
@@ -84,9 +93,16 @@ Requires-Dist: seaborn; extra == "dev"
|
|
|
84
93
|
</a>
|
|
85
94
|
</td>
|
|
86
95
|
<td style="padding: 10px;">
|
|
87
|
-
<
|
|
88
|
-
<
|
|
89
|
-
|
|
96
|
+
<div style="display: block;">
|
|
97
|
+
<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml">
|
|
98
|
+
<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?event=schedule" alt="Build">
|
|
99
|
+
</a>
|
|
100
|
+
</div>
|
|
101
|
+
<div style="display: block;">
|
|
102
|
+
<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
|
|
103
|
+
<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
|
|
104
|
+
</a>
|
|
105
|
+
</div>
|
|
90
106
|
</td>
|
|
91
107
|
</tr>
|
|
92
108
|
</table>
|
|
@@ -95,13 +111,14 @@ Requires-Dist: seaborn; extra == "dev"
|
|
|
95
111
|
|
|
96
112
|
<img src="https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/logo-banner.png">
|
|
97
113
|
|
|
98
|
-
[Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Cite our work](#cite-this-work)
|
|
114
|
+
[Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [High-level APIs](#high-level-apis) | [Low-level APIs](#low-level-apis) | [Cite our work](#cite-this-work)
|
|
99
115
|
|
|
100
116
|
<details>
|
|
101
117
|
<summary>Latest News 🔥</summary>
|
|
102
118
|
|
|
119
|
+
- [2024/12/15] We release a LinkedIn Engineering Blog post - [Liger-Kernel: Empowering an open source ecosystem of Triton Kernels for Efficient LLM Training](https://www.linkedin.com/blog/engineering/open-source/liger-kernel-open-source-ecosystem-for-efficient-llm-training)
|
|
103
120
|
- [2024/11/6] We release [v0.4.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.4.0): Full AMD support, Tech Report, Modal CI, Llama-3.2-Vision!
|
|
104
|
-
- [2024/10/21] We have released the tech report of Liger Kernel on Arxiv: https://arxiv.org/pdf/2410.10989
|
|
121
|
+
- [2024/10/21] We have released the tech report of Liger Kernel on Arxiv: https://arxiv.org/pdf/2410.10989
|
|
105
122
|
- [2024/9/6] We release v0.2.1 ([X post](https://x.com/liger_kernel/status/1832168197002510649)). 2500+ Stars, 10+ New Contributors, 50+ PRs, 50k Downloads in two weeks!
|
|
106
123
|
- [2024/8/31] CUDA MODE talk, [Liger-Kernel: Real-world Triton kernel for LLM Training](https://youtu.be/gWble4FreV4?si=dxPeIchhkJ36Mbns), [Slides](https://github.com/cuda-mode/lectures?tab=readme-ov-file#lecture-28-liger-kernel)
|
|
107
124
|
- [2024/8/23] Official release: check out our [X post](https://x.com/hsu_byron/status/1827072737673982056)
|
|
@@ -111,6 +128,8 @@ Requires-Dist: seaborn; extra == "dev"
|
|
|
111
128
|
|
|
112
129
|
**Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduce **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.
|
|
113
130
|
|
|
131
|
+
We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more.
|
|
132
|
+
|
|
114
133
|
## Supercharge Your Model with Liger Kernel
|
|
115
134
|
|
|
116
135
|

|
|
@@ -128,12 +147,13 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
|
|
|
128
147
|
|
|
129
148
|
## Examples
|
|
130
149
|
|
|
131
|
-
|
|
132
150
|
| **Use Case** | **Description** |
|
|
133
151
|
|------------------------------------------------|---------------------------------------------------------------------------------------------------|
|
|
134
152
|
| [**Hugging Face Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface) | Train LLaMA 3-8B ~20% faster with over 40% memory reduction on Alpaca dataset using 4 A100s with FSDP |
|
|
135
153
|
| [**Lightning Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/lightning) | Increase throughput by 15% and reduce memory usage by 40% with LLaMA3-8B on the MMLU dataset using 8 A100s with DeepSpeed ZeRO3 |
|
|
136
|
-
| [**Medusa Multi-head LLM (Retraining Phase)**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP
|
|
154
|
+
| [**Medusa Multi-head LLM (Retraining Phase)**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP |
|
|
155
|
+
| [**Vision-Language Model SFT**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface/run_qwen2_vl.sh) | Fine-tune Qwen2-VL on image-text data using 4 A100s with FSDP |
|
|
156
|
+
| [**Liger ORPO Trainer**](https://github.com/linkedin/Liger-Kernel/blob/main/examples/alignment/run_orpo.py) | Align Llama 3.2 using the Liger ORPO Trainer with FSDP, reducing memory usage by 50% |
|
|
137
157
|
|
|
138
158
|
## Key Features
|
|
139
159
|
|
|
@@ -146,7 +166,7 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
|
|
|
146
166
|
|
|
147
167
|
## Installation
|
|
148
168
|
|
|
149
|
-
### Dependencies
|
|
169
|
+
### Dependencies
|
|
150
170
|
|
|
151
171
|
#### CUDA
|
|
152
172
|
|
|
@@ -183,6 +203,8 @@ To install from source:
|
|
|
183
203
|
git clone https://github.com/linkedin/Liger-Kernel.git
|
|
184
204
|
cd Liger-Kernel
|
|
185
205
|
pip install -e .
|
|
206
|
+
# or if installing on amd platform
|
|
207
|
+
pip install -e .[amd] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2 # rocm6.2
|
|
186
208
|
# or if using transformers
|
|
187
209
|
pip install -e .[transformers]
|
|
188
210
|
```
|
|
@@ -249,7 +271,7 @@ loss = loss_fn(model.weight, input, target)
|
|
|
249
271
|
loss.backward()
|
|
250
272
|
```
|
|
251
273
|
|
|
252
|
-
## APIs
|
|
274
|
+
## High-level APIs
|
|
253
275
|
|
|
254
276
|
### AutoModel
|
|
255
277
|
|
|
@@ -268,13 +290,17 @@ loss.backward()
|
|
|
268
290
|
| Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
269
291
|
| Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
270
292
|
| Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
271
|
-
| Qwen2
|
|
293
|
+
| Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
272
294
|
| Qwen2-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
273
295
|
| Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
274
296
|
|
|
275
297
|
|
|
298
|
+
## Low-level APIs
|
|
276
299
|
|
|
277
|
-
|
|
300
|
+
- `Fused Linear` kernels combine linear layers with losses, reducing memory usage by up to 80% - ideal for HBM-constrained workloads.
|
|
301
|
+
- Other kernels use fusion and in-place techniques for memory and performance optimization.
|
|
302
|
+
|
|
303
|
+
### Model Kernels
|
|
278
304
|
|
|
279
305
|
| **Kernel** | **API** |
|
|
280
306
|
|---------------------------------|-------------------------------------------------------------|
|
|
@@ -284,39 +310,33 @@ loss.backward()
|
|
|
284
310
|
| SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
|
|
285
311
|
| GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
|
|
286
312
|
| CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
|
|
287
|
-
|
|
|
313
|
+
| Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
### Alignment Kernels
|
|
317
|
+
|
|
318
|
+
| **Kernel** | **API** |
|
|
319
|
+
|---------------------------------|-------------------------------------------------------------|
|
|
320
|
+
| Fused Linear CPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearCPOLoss` |
|
|
321
|
+
| Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
|
|
322
|
+
| Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
|
|
323
|
+
| Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
|
|
324
|
+
|
|
325
|
+
### Distillation Kernels
|
|
326
|
+
|
|
327
|
+
| **Kernel** | **API** |
|
|
328
|
+
|---------------------------------|-------------------------------------------------------------|
|
|
288
329
|
| KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
|
|
289
330
|
| JSD | `liger_kernel.transformers.LigerJSD` |
|
|
290
|
-
|
|
|
291
|
-
|
|
292
|
-
- **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
|
|
293
|
-
- **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
|
|
294
|
-
- **GroupNorm**: [GroupNorm](https://arxiv.org/pdf/1803.08494), which normalizes activations across the group dimension for a given sample. Channels are grouped in K groups over which the normalization is performed, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and can achieve up to ~2X speedup as the number of channels/groups increases.
|
|
295
|
-
- **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction.
|
|
296
|
-
- **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
|
|
297
|
-
$$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$
|
|
298
|
-
, is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction.
|
|
299
|
-
- **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
|
|
300
|
-
$$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
|
|
301
|
-
, is implemented by fusing the elementwise multiplication into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction. Note that the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used.
|
|
302
|
-
- **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with inplace replacement of input to reduce the peak memory by avoiding simultaneous materialization of both input logits and gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.).
|
|
303
|
-
<!-- TODO: verify vocab sizes are accurate -->
|
|
304
|
-
- **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
|
|
305
|
-
- **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
|
|
306
|
-
- **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size.
|
|
307
|
-
- **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192.
|
|
308
|
-
|
|
331
|
+
| Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
|
|
309
332
|
|
|
310
333
|
### Experimental Kernels
|
|
311
334
|
|
|
312
335
|
| **Kernel** | **API** |
|
|
313
336
|
|---------------------------------|-------------------------------------------------------------|
|
|
314
337
|
| Embedding | `liger_kernel.transformers.experimental.LigerEmbedding` |
|
|
315
|
-
| Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul`
|
|
338
|
+
| Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul` |
|
|
316
339
|
|
|
317
|
-
- **Embedding**: [Embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) is implemented by fusing embedding lookup and output operations. It achieves a peak speedup of ~1.5x in the forward pass and an overall speedup of ~1.1x.
|
|
318
|
-
- **Matmul int2xint8**: is implemented by using the cache tiled matrix multiplication and by fusing the matmul with the unpacking process which achieves a considerable speed up and performs on par with @torch.compile
|
|
319
|
-
<!-- TODO: be more specific about batch size -->
|
|
320
340
|
|
|
321
341
|
## Contributing, Acknowledgements, and License
|
|
322
342
|
|
|
@@ -324,6 +344,17 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
|
|
|
324
344
|
- [Acknowledgements](https://github.com/linkedin/Liger-Kernel/blob/main/docs/Acknowledgement.md)
|
|
325
345
|
- [License Information](https://github.com/linkedin/Liger-Kernel/blob/main/docs/License.md)
|
|
326
346
|
|
|
347
|
+
## Sponsorship and Collaboration
|
|
348
|
+
|
|
349
|
+
- [AMD](https://www.amd.com/en.html): Providing AMD GPUs for our AMD CI.
|
|
350
|
+
- [Intel](https://www.intel.com/): Providing Intel GPUs for our Intel CI.
|
|
351
|
+
- [Modal](https://modal.com/): Providing 3000 free credits (from GPU MODE IRL) for our NVIDIA CI.
|
|
352
|
+
- [EmbeddedLLM](https://embeddedllm.com/): Making Liger Kernel run fast and stable on AMD.
|
|
353
|
+
- [HuggingFace](https://huggingface.co/): Integrating Liger Kernel into Hugging Face Transformers and TRL.
|
|
354
|
+
- [Lightning AI](https://lightning.ai/): Integrating Liger Kernel into Lightning Thunder.
|
|
355
|
+
- [Axolotl](https://axolotl.ai/): Integrating Liger Kernel into Axolotl.
|
|
356
|
+
- [Llama-Factory](https://github.com/hiyouga/LLaMA-Factory): Integrating Liger Kernel into Llama-Factory.
|
|
357
|
+
|
|
327
358
|
## Contact
|
|
328
359
|
|
|
329
360
|
- For issues, create a Github ticket in this repository
|
|
@@ -335,7 +366,7 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
|
|
|
335
366
|
Biblatex entry:
|
|
336
367
|
```bib
|
|
337
368
|
@article{hsu2024ligerkernelefficienttriton,
|
|
338
|
-
title={Liger Kernel: Efficient Triton Kernels for LLM Training},
|
|
369
|
+
title={Liger Kernel: Efficient Triton Kernels for LLM Training},
|
|
339
370
|
author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen},
|
|
340
371
|
year={2024},
|
|
341
372
|
eprint={2410.10989},
|
|
@@ -349,15 +380,8 @@ Biblatex entry:
|
|
|
349
380
|
## Star History
|
|
350
381
|
[](https://star-history.com/#linkedin/Liger-Kernel&Date)
|
|
351
382
|
|
|
352
|
-
## Contributors
|
|
353
|
-
|
|
354
|
-
<a href="https://github.com/linkedin/Liger-Kernel/graphs/contributors">
|
|
355
|
-
<img alt="contributors" src="https://contrib.rocks/image?repo=linkedin/Liger-Kernel"/>
|
|
356
|
-
</a>
|
|
357
|
-
|
|
358
383
|
<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
|
|
359
384
|
<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
|
|
360
385
|
↑ Back to Top ↑
|
|
361
386
|
</a>
|
|
362
387
|
</p>
|
|
363
|
-
|
|
@@ -8,7 +8,7 @@
|
|
|
8
8
|
<th style="padding: 10px;" colspan="2">Stable</th>
|
|
9
9
|
<th style="padding: 10px;" colspan="2">Nightly</th>
|
|
10
10
|
<th style="padding: 10px;">Discord</th>
|
|
11
|
-
<th style="padding: 10px;">
|
|
11
|
+
<th style="padding: 10px;">Build</th>
|
|
12
12
|
</tr>
|
|
13
13
|
<tr>
|
|
14
14
|
<td style="padding: 10px;">
|
|
@@ -37,9 +37,16 @@
|
|
|
37
37
|
</a>
|
|
38
38
|
</td>
|
|
39
39
|
<td style="padding: 10px;">
|
|
40
|
-
<
|
|
41
|
-
<
|
|
42
|
-
|
|
40
|
+
<div style="display: block;">
|
|
41
|
+
<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml">
|
|
42
|
+
<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?event=schedule" alt="Build">
|
|
43
|
+
</a>
|
|
44
|
+
</div>
|
|
45
|
+
<div style="display: block;">
|
|
46
|
+
<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
|
|
47
|
+
<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
|
|
48
|
+
</a>
|
|
49
|
+
</div>
|
|
43
50
|
</td>
|
|
44
51
|
</tr>
|
|
45
52
|
</table>
|
|
@@ -48,13 +55,14 @@
|
|
|
48
55
|
|
|
49
56
|
<img src="https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/logo-banner.png">
|
|
50
57
|
|
|
51
|
-
[Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Cite our work](#cite-this-work)
|
|
58
|
+
[Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [High-level APIs](#high-level-apis) | [Low-level APIs](#low-level-apis) | [Cite our work](#cite-this-work)
|
|
52
59
|
|
|
53
60
|
<details>
|
|
54
61
|
<summary>Latest News 🔥</summary>
|
|
55
62
|
|
|
63
|
+
- [2024/12/15] We release a LinkedIn Engineering Blog post - [Liger-Kernel: Empowering an open source ecosystem of Triton Kernels for Efficient LLM Training](https://www.linkedin.com/blog/engineering/open-source/liger-kernel-open-source-ecosystem-for-efficient-llm-training)
|
|
56
64
|
- [2024/11/6] We release [v0.4.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.4.0): Full AMD support, Tech Report, Modal CI, Llama-3.2-Vision!
|
|
57
|
-
- [2024/10/21] We have released the tech report of Liger Kernel on Arxiv: https://arxiv.org/pdf/2410.10989
|
|
65
|
+
- [2024/10/21] We have released the tech report of Liger Kernel on Arxiv: https://arxiv.org/pdf/2410.10989
|
|
58
66
|
- [2024/9/6] We release v0.2.1 ([X post](https://x.com/liger_kernel/status/1832168197002510649)). 2500+ Stars, 10+ New Contributors, 50+ PRs, 50k Downloads in two weeks!
|
|
59
67
|
- [2024/8/31] CUDA MODE talk, [Liger-Kernel: Real-world Triton kernel for LLM Training](https://youtu.be/gWble4FreV4?si=dxPeIchhkJ36Mbns), [Slides](https://github.com/cuda-mode/lectures?tab=readme-ov-file#lecture-28-liger-kernel)
|
|
60
68
|
- [2024/8/23] Official release: check out our [X post](https://x.com/hsu_byron/status/1827072737673982056)
|
|
@@ -64,6 +72,8 @@
|
|
|
64
72
|
|
|
65
73
|
**Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduce **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.
|
|
66
74
|
|
|
75
|
+
We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more.
|
|
76
|
+
|
|
67
77
|
## Supercharge Your Model with Liger Kernel
|
|
68
78
|
|
|
69
79
|

|
|
@@ -81,12 +91,13 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
|
|
|
81
91
|
|
|
82
92
|
## Examples
|
|
83
93
|
|
|
84
|
-
|
|
85
94
|
| **Use Case** | **Description** |
|
|
86
95
|
|------------------------------------------------|---------------------------------------------------------------------------------------------------|
|
|
87
96
|
| [**Hugging Face Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface) | Train LLaMA 3-8B ~20% faster with over 40% memory reduction on Alpaca dataset using 4 A100s with FSDP |
|
|
88
97
|
| [**Lightning Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/lightning) | Increase throughput by 15% and reduce memory usage by 40% with LLaMA3-8B on the MMLU dataset using 8 A100s with DeepSpeed ZeRO3 |
|
|
89
|
-
| [**Medusa Multi-head LLM (Retraining Phase)**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP
|
|
98
|
+
| [**Medusa Multi-head LLM (Retraining Phase)**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP |
|
|
99
|
+
| [**Vision-Language Model SFT**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface/run_qwen2_vl.sh) | Fine-tune Qwen2-VL on image-text data using 4 A100s with FSDP |
|
|
100
|
+
| [**Liger ORPO Trainer**](https://github.com/linkedin/Liger-Kernel/blob/main/examples/alignment/run_orpo.py) | Align Llama 3.2 using the Liger ORPO Trainer with FSDP, reducing memory usage by 50% |
|
|
90
101
|
|
|
91
102
|
## Key Features
|
|
92
103
|
|
|
@@ -99,7 +110,7 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
|
|
|
99
110
|
|
|
100
111
|
## Installation
|
|
101
112
|
|
|
102
|
-
### Dependencies
|
|
113
|
+
### Dependencies
|
|
103
114
|
|
|
104
115
|
#### CUDA
|
|
105
116
|
|
|
@@ -136,6 +147,8 @@ To install from source:
|
|
|
136
147
|
git clone https://github.com/linkedin/Liger-Kernel.git
|
|
137
148
|
cd Liger-Kernel
|
|
138
149
|
pip install -e .
|
|
150
|
+
# or if installing on amd platform
|
|
151
|
+
pip install -e .[amd] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2 # rocm6.2
|
|
139
152
|
# or if using transformers
|
|
140
153
|
pip install -e .[transformers]
|
|
141
154
|
```
|
|
@@ -202,7 +215,7 @@ loss = loss_fn(model.weight, input, target)
|
|
|
202
215
|
loss.backward()
|
|
203
216
|
```
|
|
204
217
|
|
|
205
|
-
## APIs
|
|
218
|
+
## High-level APIs
|
|
206
219
|
|
|
207
220
|
### AutoModel
|
|
208
221
|
|
|
@@ -221,13 +234,17 @@ loss.backward()
|
|
|
221
234
|
| Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
222
235
|
| Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
223
236
|
| Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
224
|
-
| Qwen2
|
|
237
|
+
| Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
225
238
|
| Qwen2-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
226
239
|
| Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
227
240
|
|
|
228
241
|
|
|
242
|
+
## Low-level APIs
|
|
229
243
|
|
|
230
|
-
|
|
244
|
+
- `Fused Linear` kernels combine linear layers with losses, reducing memory usage by up to 80% - ideal for HBM-constrained workloads.
|
|
245
|
+
- Other kernels use fusion and in-place techniques for memory and performance optimization.
|
|
246
|
+
|
|
247
|
+
### Model Kernels
|
|
231
248
|
|
|
232
249
|
| **Kernel** | **API** |
|
|
233
250
|
|---------------------------------|-------------------------------------------------------------|
|
|
@@ -237,39 +254,33 @@ loss.backward()
|
|
|
237
254
|
| SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
|
|
238
255
|
| GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
|
|
239
256
|
| CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
|
|
240
|
-
|
|
|
257
|
+
| Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
### Alignment Kernels
|
|
261
|
+
|
|
262
|
+
| **Kernel** | **API** |
|
|
263
|
+
|---------------------------------|-------------------------------------------------------------|
|
|
264
|
+
| Fused Linear CPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearCPOLoss` |
|
|
265
|
+
| Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
|
|
266
|
+
| Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
|
|
267
|
+
| Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
|
|
268
|
+
|
|
269
|
+
### Distillation Kernels
|
|
270
|
+
|
|
271
|
+
| **Kernel** | **API** |
|
|
272
|
+
|---------------------------------|-------------------------------------------------------------|
|
|
241
273
|
| KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
|
|
242
274
|
| JSD | `liger_kernel.transformers.LigerJSD` |
|
|
243
|
-
|
|
|
244
|
-
|
|
245
|
-
- **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
|
|
246
|
-
- **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
|
|
247
|
-
- **GroupNorm**: [GroupNorm](https://arxiv.org/pdf/1803.08494), which normalizes activations across the group dimension for a given sample. Channels are grouped in K groups over which the normalization is performed, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and can achieve up to ~2X speedup as the number of channels/groups increases.
|
|
248
|
-
- **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction.
|
|
249
|
-
- **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
|
|
250
|
-
$$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$
|
|
251
|
-
, is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction.
|
|
252
|
-
- **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
|
|
253
|
-
$$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
|
|
254
|
-
, is implemented by fusing the elementwise multiplication into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction. Note that the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used.
|
|
255
|
-
- **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with inplace replacement of input to reduce the peak memory by avoiding simultaneous materialization of both input logits and gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.).
|
|
256
|
-
<!-- TODO: verify vocab sizes are accurate -->
|
|
257
|
-
- **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
|
|
258
|
-
- **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
|
|
259
|
-
- **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size.
|
|
260
|
-
- **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192.
|
|
261
|
-
|
|
275
|
+
| Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
|
|
262
276
|
|
|
263
277
|
### Experimental Kernels
|
|
264
278
|
|
|
265
279
|
| **Kernel** | **API** |
|
|
266
280
|
|---------------------------------|-------------------------------------------------------------|
|
|
267
281
|
| Embedding | `liger_kernel.transformers.experimental.LigerEmbedding` |
|
|
268
|
-
| Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul`
|
|
282
|
+
| Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul` |
|
|
269
283
|
|
|
270
|
-
- **Embedding**: [Embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) is implemented by fusing embedding lookup and output operations. It achieves a peak speedup of ~1.5x in the forward pass and an overall speedup of ~1.1x.
|
|
271
|
-
- **Matmul int2xint8**: is implemented by using the cache tiled matrix multiplication and by fusing the matmul with the unpacking process which achieves a considerable speed up and performs on par with @torch.compile
|
|
272
|
-
<!-- TODO: be more specific about batch size -->
|
|
273
284
|
|
|
274
285
|
## Contributing, Acknowledgements, and License
|
|
275
286
|
|
|
@@ -277,6 +288,17 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
|
|
|
277
288
|
- [Acknowledgements](https://github.com/linkedin/Liger-Kernel/blob/main/docs/Acknowledgement.md)
|
|
278
289
|
- [License Information](https://github.com/linkedin/Liger-Kernel/blob/main/docs/License.md)
|
|
279
290
|
|
|
291
|
+
## Sponsorship and Collaboration
|
|
292
|
+
|
|
293
|
+
- [AMD](https://www.amd.com/en.html): Providing AMD GPUs for our AMD CI.
|
|
294
|
+
- [Intel](https://www.intel.com/): Providing Intel GPUs for our Intel CI.
|
|
295
|
+
- [Modal](https://modal.com/): Providing 3000 free credits from GPU MODE IRL for our NVIDIA CI.
|
|
296
|
+
- [EmbeddedLLM](https://embeddedllm.com/): Making Liger Kernel run fast and stable on AMD.
|
|
297
|
+
- [HuggingFace](https://huggingface.co/): Integrating Liger Kernel into Hugging Face Transformers and TRL.
|
|
298
|
+
- [Lightning AI](https://lightning.ai/): Integrating Liger Kernel into Lightning Thunder.
|
|
299
|
+
- [Axolotl](https://axolotl.ai/): Integrating Liger Kernel into Axolotl.
|
|
300
|
+
- [Llama-Factory](https://github.com/hiyouga/LLaMA-Factory): Integrating Liger Kernel into Llama-Factory.
|
|
301
|
+
|
|
280
302
|
## Contact
|
|
281
303
|
|
|
282
304
|
- For issues, create a GitHub ticket in this repository
|
|
@@ -288,7 +310,7 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
|
|
|
288
310
|
Biblatex entry:
|
|
289
311
|
```bib
|
|
290
312
|
@article{hsu2024ligerkernelefficienttriton,
|
|
291
|
-
title={Liger Kernel: Efficient Triton Kernels for LLM Training},
|
|
313
|
+
title={Liger Kernel: Efficient Triton Kernels for LLM Training},
|
|
292
314
|
author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen},
|
|
293
315
|
year={2024},
|
|
294
316
|
eprint={2410.10989},
|
|
@@ -302,15 +324,8 @@ Biblatex entry:
|
|
|
302
324
|
## Star History
|
|
303
325
|
[](https://star-history.com/#linkedin/Liger-Kernel&Date)
|
|
304
326
|
|
|
305
|
-
## Contributors
|
|
306
|
-
|
|
307
|
-
<a href="https://github.com/linkedin/Liger-Kernel/graphs/contributors">
|
|
308
|
-
<img alt="contributors" src="https://contrib.rocks/image?repo=linkedin/Liger-Kernel"/>
|
|
309
|
-
</a>
|
|
310
|
-
|
|
311
327
|
<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
|
|
312
328
|
<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
|
|
313
329
|
↑ Back to Top ↑
|
|
314
330
|
</a>
|
|
315
331
|
</p>
|
|
316
|
-
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "liger_kernel"
|
|
7
|
-
version = "0.4.2"
|
|
7
|
+
version = "0.5.1"
|
|
8
8
|
description = "Efficient Triton kernels for LLM Training"
|
|
9
9
|
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
|
|
10
10
|
readme = { file = "README.md", content-type = "text/markdown" }
|
|
@@ -14,11 +14,16 @@ dependencies = [
|
|
|
14
14
|
"triton>=2.3.1",
|
|
15
15
|
]
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
[project.optional-dependencies]
|
|
18
19
|
transformers = [
|
|
19
20
|
"transformers~=4.0"
|
|
20
21
|
]
|
|
21
22
|
|
|
23
|
+
trl = [
|
|
24
|
+
"trl>=0.11.0",
|
|
25
|
+
]
|
|
26
|
+
|
|
22
27
|
dev = [
|
|
23
28
|
"transformers>=4.44.2",
|
|
24
29
|
"matplotlib>=3.7.2",
|
|
@@ -26,11 +31,22 @@ dev = [
|
|
|
26
31
|
"black>=24.4.2",
|
|
27
32
|
"isort>=5.13.2",
|
|
28
33
|
"pytest>=7.1.2",
|
|
34
|
+
"pytest-xdist",
|
|
35
|
+
"pytest-rerunfailures",
|
|
29
36
|
"datasets>=2.19.2",
|
|
30
37
|
"torchvision>=0.16.2",
|
|
31
38
|
"seaborn",
|
|
32
39
|
]
|
|
33
40
|
|
|
41
|
+
amd = [
|
|
42
|
+
"torch>=2.6.0.dev",
|
|
43
|
+
"setuptools-scm>=8",
|
|
44
|
+
"torchvision>=0.20.0.dev",
|
|
45
|
+
"triton>=3.0.0",
|
|
46
|
+
]
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
|
|
34
50
|
[tool.setuptools.packages.find]
|
|
35
51
|
where = ["src"]
|
|
36
52
|
include = ["liger_kernel", "liger_kernel.*"]
|
|
@@ -0,0 +1,4 @@
|
|
|
1
|
+
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
|
|
2
|
+
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
|
|
3
|
+
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
|
|
4
|
+
from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn.functional as F
|
|
3
|
+
|
|
4
|
+
from liger_kernel.chunked_loss.fused_linear_preference import (
|
|
5
|
+
LigerFusedLinearPreferenceBase,
|
|
6
|
+
)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
    """Autograd function for a fused linear layer followed by the CPO loss.

    Chunking, the linear projection and gradient bookkeeping are delegated to
    ``LigerFusedLinearPreferenceBase``; this class only supplies the
    CPO-specific preference term (``preference_loss_fn``) and adapts the
    ``forward``/``backward`` signatures.
    """

    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
        """
        Paper: https://arxiv.org/pdf/2401.08417

        Formula:
        L(π_θ; U) = -E_(x,y_w,y_l)~D[log σ(β log π_θ(y_w|x) - β log π_θ(y_l|x))]

        Where:
        - π_θ(y|x): Policy (model) probability
        - y_w: Chosen sequence
        - y_l: Rejected sequence
        - σ: Sigmoid function
        - β: Temperature parameter
        - E: Expected value over the dataset D
        - D: Dataset of preferences

        Sign convention:
            The value returned here is ``+log σ(...)`` averaged as described
            below. It does NOT include the leading minus sign of ``L`` in the
            formula above, so it is ``-L``. NOTE(review): presumably the
            caller (``LigerFusedLinearPreferenceBase``) subtracts this term
            when combining it with the NLL loss. Confirm before reusing this
            function elsewhere.

        Normalisation:
            The sum over the given (possibly chunked) log-probs is divided by
            ``full_target.shape[0] // 2``, i.e. half the first dimension of
            the full, non-chunked target. Presumably ``full_target`` stacks
            chosen and rejected rows along dim 0, so per-chunk results sum to
            a mean over preference pairs. TODO confirm.

        Args:
            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
            full_target (torch.Tensor): Non-chunked full target tensor; only ``shape[0]`` is read.
            beta (float): Temperature β scaling the chosen/rejected log-prob difference.

        Returns:
            torch.Tensor: Scalar
            ``logsigmoid(beta * (chosen_logps - rejected_logps)).sum() / (full_target.shape[0] // 2)``.
        """
        logits = beta * (chosen_logps - rejected_logps)
        loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
        return loss

    @staticmethod
    def forward(
        ctx,
        _input,
        weight,
        target,
        bias=None,
        ignore_index=-100,
        beta=0.1,
        alpha=1.0,
        compute_nll_loss=True,
        compiled=True,
    ):
        """Run the fused linear + CPO forward pass.

        Delegates to ``LigerFusedLinearPreferenceBase.forward``. It passes
        ``preference_loss_fn`` as ``loss_fn`` and forwards every other
        argument unchanged.

        Args:
            ctx: Autograd context.
            _input (torch.Tensor): Input to the fused linear layer
                (presumably hidden states of chosen and rejected sequences;
                confirm).
            weight (torch.Tensor): Weight of the fused linear layer.
            target (torch.Tensor): Targets matching ``_input``.
            bias (torch.Tensor, optional): Bias of the fused linear layer.
            ignore_index (int): Index to ignore in the loss.
            beta (float): Temperature β of the CPO preference term.
            alpha (float): Forwarded to the base as ``alpha``.
                NOTE(review): presumably the weight of the NLL term; confirm.
            compute_nll_loss (bool): Forwarded to the base; presumably toggles
                the NLL term. Confirm.
            compiled (bool): Forwarded to the base; presumably enables
                ``torch.compile`` of the chunk computation. Confirm.

        Returns:
            Whatever ``LigerFusedLinearPreferenceBase.forward`` returns.
        """
        return LigerFusedLinearPreferenceBase.forward(
            ctx,
            _input,
            weight,
            target,
            bias,
            loss_fn=LigerFusedLinearCPOFunction.preference_loss_fn,
            ignore_index=ignore_index,
            alpha=alpha,
            beta=beta,
            compute_nll_loss=compute_nll_loss,
            compiled=compiled,
        )

    @staticmethod
    def backward(ctx, *grad_output):
        """Return one gradient slot per ``forward`` input (nine in total).

        The first four inputs (``_input``, ``weight``, ``target``, ``bias``)
        take their gradients from the base implementation. The five
        hyperparameters (``ignore_index``, ``beta``, ``alpha``,
        ``compute_nll_loss``, ``compiled``) are non-differentiable and get
        ``None``.
        """
        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
        # Pad with None for the five non-tensor hyperparameters of forward().
        return *grads, None, None, None, None, None
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class LigerFusedLinearCPOLoss(torch.nn.Module):
    """
    Fused linear layer with CPO loss.

    ``torch.nn.Module`` wrapper around ``LigerFusedLinearCPOFunction``. It
    stores the loss hyperparameters at construction time, so ``forward`` only
    needs the tensors.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        alpha: float = 1.0,
        compute_nll_loss: bool = True,
        compiled: bool = True,
    ) -> None:
        """
        Args:
            ignore_index (int): Index to ignore in the loss.
            beta (float): Temperature β of the CPO preference term, i.e. the
                factor applied to ``chosen_logps - rejected_logps``.
            alpha (float): Forwarded to the preference base as ``alpha``.
                NOTE(review): presumably the weight of the NLL term; confirm.
            compute_nll_loss (bool): Forwarded to the preference base;
                presumably toggles the NLL term. Confirm.
            compiled (bool): Forwarded to the preference base; presumably
                enables ``torch.compile`` of the chunked computation. Confirm.
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.beta = beta
        self.alpha = alpha
        self.compute_nll_loss = compute_nll_loss
        self.compiled = compiled

    def forward(self, lin_weight, _input, target, bias=None):
        """Compute the fused linear + CPO loss.

        Note the argument order: this module takes ``lin_weight`` first,
        whereas ``LigerFusedLinearCPOFunction.apply`` takes ``_input`` first.
        The two are swapped below when calling ``apply``.

        Args:
            lin_weight (torch.Tensor): Weight of the fused linear layer.
            _input (torch.Tensor): Input to the fused linear layer.
            target (torch.Tensor): Targets matching ``_input``.
            bias (torch.Tensor, optional): Bias of the fused linear layer.

        Returns:
            The output of ``LigerFusedLinearCPOFunction.apply``.
        """
        return LigerFusedLinearCPOFunction.apply(
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.beta,
            self.alpha,
            self.compute_nll_loss,
            self.compiled,
        )
|