liger-kernel 0.4.2__tar.gz → 0.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {liger_kernel-0.4.2/src/liger_kernel.egg-info → liger_kernel-0.5.0}/PKG-INFO +68 -45
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/README.md +59 -44
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/pyproject.toml +14 -1
- liger_kernel-0.5.0/src/liger_kernel/chunked_loss/__init__.py +4 -0
- liger_kernel-0.5.0/src/liger_kernel/chunked_loss/cpo_loss.py +107 -0
- liger_kernel-0.5.0/src/liger_kernel/chunked_loss/dpo_loss.py +135 -0
- liger_kernel-0.5.0/src/liger_kernel/chunked_loss/functional.py +9 -0
- liger_kernel-0.5.0/src/liger_kernel/chunked_loss/fused_linear_distillation.py +252 -0
- liger_kernel-0.5.0/src/liger_kernel/chunked_loss/fused_linear_preference.py +386 -0
- liger_kernel-0.5.0/src/liger_kernel/chunked_loss/orpo_loss.py +113 -0
- liger_kernel-0.5.0/src/liger_kernel/chunked_loss/simpo_loss.py +115 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/env_report.py +22 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/ops/cross_entropy.py +17 -10
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -11
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/ops/fused_linear_jsd.py +1 -1
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/ops/jsd.py +19 -10
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/ops/layer_norm.py +6 -1
- liger_kernel-0.5.0/src/liger_kernel/ops/qwen2vl_mrope.py +238 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/ops/rms_norm.py +6 -1
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/ops/utils.py +5 -2
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/__init__.py +1 -0
- liger_kernel-0.5.0/src/liger_kernel/transformers/functional.py +173 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/fused_linear_jsd.py +1 -4
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/jsd.py +1 -4
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/monkey_patch.py +6 -4
- liger_kernel-0.5.0/src/liger_kernel/transformers/orpo_trainer.py +171 -0
- liger_kernel-0.5.0/src/liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel-0.5.0/src/liger_kernel/utils.py +13 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0/src/liger_kernel.egg-info}/PKG-INFO +68 -45
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel.egg-info/SOURCES.txt +9 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel.egg-info/requires.txt +9 -0
- liger_kernel-0.4.2/src/liger_kernel/chunked_loss/dpo_loss.py +0 -57
- liger_kernel-0.4.2/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -206
- liger_kernel-0.4.2/src/liger_kernel/chunked_loss/orpo_loss.py +0 -63
- liger_kernel-0.4.2/src/liger_kernel/transformers/functional.py +0 -56
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/LICENSE +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/NOTICE +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/setup.cfg +0 -0
- {liger_kernel-0.4.2/src/liger_kernel/chunked_loss → liger_kernel-0.5.0/src/liger_kernel}/__init__.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel.egg-info/dependency_links.txt +0 -0
- {liger_kernel-0.4.2 → liger_kernel-0.5.0}/src/liger_kernel.egg-info/top_level.txt +0 -0
{liger_kernel-0.4.2/src/liger_kernel.egg-info → liger_kernel-0.5.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel
-Version: 0.4.2
+Version: 0.5.0
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
@@ -36,14 +36,22 @@ Provides-Extra: transformers
 Requires-Dist: transformers~=4.0; extra == "transformers"
 Provides-Extra: dev
 Requires-Dist: transformers>=4.44.2; extra == "dev"
+Requires-Dist: trl>=0.11.0; extra == "dev"
 Requires-Dist: matplotlib>=3.7.2; extra == "dev"
 Requires-Dist: flake8>=4.0.1.1; extra == "dev"
 Requires-Dist: black>=24.4.2; extra == "dev"
 Requires-Dist: isort>=5.13.2; extra == "dev"
 Requires-Dist: pytest>=7.1.2; extra == "dev"
+Requires-Dist: pytest-xdist; extra == "dev"
+Requires-Dist: pytest-rerunfailures; extra == "dev"
 Requires-Dist: datasets>=2.19.2; extra == "dev"
 Requires-Dist: torchvision>=0.16.2; extra == "dev"
 Requires-Dist: seaborn; extra == "dev"
+Provides-Extra: amd
+Requires-Dist: torch>=2.6.0.dev; extra == "amd"
+Requires-Dist: setuptools-scm>=8; extra == "amd"
+Requires-Dist: torchvision>=0.20.0.dev; extra == "amd"
+Requires-Dist: triton>=3.0.0; extra == "amd"
 
 <a name="readme-top"></a>
 
@@ -55,7 +63,7 @@ Requires-Dist: seaborn; extra == "dev"
 <th style="padding: 10px;" colspan="2">Stable</th>
 <th style="padding: 10px;" colspan="2">Nightly</th>
 <th style="padding: 10px;">Discord</th>
-<th style="padding: 10px;">
+<th style="padding: 10px;">Build</th>
 </tr>
 <tr>
 <td style="padding: 10px;">
@@ -84,9 +92,16 @@ Requires-Dist: seaborn; extra == "dev"
 </a>
 </td>
 <td style="padding: 10px;">
-<
-<
-
+<div style="display: block;">
+<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml">
+<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?event=schedule" alt="Build">
+</a>
+</div>
+<div style="display: block;">
+<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
+<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
+</a>
+</div>
 </td>
 </tr>
 </table>
@@ -95,13 +110,14 @@ Requires-Dist: seaborn; extra == "dev"
 
 <img src="https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/logo-banner.png">
 
-[Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Cite our work](#cite-this-work)
+[Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [High-level APIs](#high-level-apis) | [Low-level APIs](#low-level-apis) | [Cite our work](#cite-this-work)
 
 <details>
 <summary>Latest News 🔥</summary>
 
+- [2024/12/15] We release LinkedIn Engineering Blog - [Liger-Kernel: Empowering an open source ecosystem of Triton Kernels for Efficient LLM Training](https://www.linkedin.com/blog/engineering/open-source/liger-kernel-open-source-ecosystem-for-efficient-llm-training)
 - [2024/11/6] We release [v0.4.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.4.0): Full AMD support, Tech Report, Modal CI, Llama-3.2-Vision!
-- [2024/10/21] We have released the tech report of Liger Kernel on Arxiv: https://arxiv.org/pdf/2410.10989
+- [2024/10/21] We have released the tech report of Liger Kernel on Arxiv: https://arxiv.org/pdf/2410.10989
 - [2024/9/6] We release v0.2.1 ([X post](https://x.com/liger_kernel/status/1832168197002510649)). 2500+ Stars, 10+ New Contributors, 50+ PRs, 50k Downloads in two weeks!
 - [2024/8/31] CUDA MODE talk, [Liger-Kernel: Real-world Triton kernel for LLM Training](https://youtu.be/gWble4FreV4?si=dxPeIchhkJ36Mbns), [Slides](https://github.com/cuda-mode/lectures?tab=readme-ov-file#lecture-28-liger-kernel)
 - [2024/8/23] Official release: check out our [X post](https://x.com/hsu_byron/status/1827072737673982056)
@@ -111,6 +127,8 @@ Requires-Dist: seaborn; extra == "dev"
 
 **Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduces **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.
 
+We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more.
+
 ## Supercharge Your Model with Liger Kernel
 
 
@@ -128,12 +146,13 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
 
 ## Examples
 
-
 | **Use Case** | **Description** |
 |------------------------------------------------|---------------------------------------------------------------------------------------------------|
 | [**Hugging Face Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface) | Train LLaMA 3-8B ~20% faster with over 40% memory reduction on Alpaca dataset using 4 A100s with FSDP |
 | [**Lightning Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/lightning) | Increase 15% throughput and reduce memory usage by 40% with LLaMA3-8B on MMLU dataset using 8 A100s with DeepSpeed ZeRO3 |
-| [**Medusa Multi-head LLM (Retraining Phase)**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP
+| [**Medusa Multi-head LLM (Retraining Phase)**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP |
+| [**Vision-Language Model SFT**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface/run_qwen2_vl.sh) | Finetune Qwen2-VL on image-text data using 4 A100s with FSDP |
+| [**Liger ORPO Trainer**](https://github.com/linkedin/Liger-Kernel/blob/main/examples/alignment/run_orpo.py) | Align Llama 3.2 using Liger ORPO Trainer with FSDP with 50% memory reduction |
 
 ## Key Features
 
@@ -146,7 +165,7 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
 
 ## Installation
 
-### Dependencies
+### Dependencies
 
 #### CUDA
 
@@ -183,6 +202,8 @@ To install from source:
 git clone https://github.com/linkedin/Liger-Kernel.git
 cd Liger-Kernel
 pip install -e .
+# or if installing on amd platform
+pip install -e .[amd] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2 # rocm6.2
 # or if using transformers
 pip install -e .[transformers]
 ```
@@ -249,7 +270,7 @@ loss = loss_fn(model.weight, input, target)
 loss.backward()
 ```
 
-## APIs
+## High-level APIs
 
 ### AutoModel
 
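The renamed "High-level APIs" section above covers two entry points: the AutoModel wrapper and the per-architecture patching functions. A minimal usage sketch of the AutoModel path, assuming the `AutoLigerKernelForCausalLM` class exported by `liger_kernel.transformers` (its `auto_model.py` is unchanged in this diff); the checkpoint name is illustrative:

```python
# Minimal sketch, assuming AutoLigerKernelForCausalLM as exported by
# liger_kernel.transformers; the checkpoint name is illustrative.
from liger_kernel.transformers import AutoLigerKernelForCausalLM

# Drop-in replacement for transformers.AutoModelForCausalLM: it applies the
# matching apply_liger_kernel_to_* patch before instantiating the model.
model = AutoLigerKernelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
```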
@@ -268,13 +289,17 @@ loss.backward()
 | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
-| Qwen2
+| Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 
 
+## Low-level APIs
 
-
+- `Fused Linear` kernels combine linear layers with losses, reducing memory usage by up to 80% - ideal for HBM-constrained workloads.
+- Other kernels use fusion and in-place techniques for memory and performance optimization.
+
+### Model Kernels
 
 | **Kernel** | **API** |
 |---------------------------------|-------------------------------------------------------------|
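The patching table above is the other high-level entry point; the new row routes Qwen2.5 and QwQ through the existing Qwen2 patch. A minimal sketch of how these `apply_liger_kernel_to_*` functions are used, with an illustrative checkpoint name:

```python
# Minimal sketch of the apply_liger_kernel_to_* APIs from the table above;
# the checkpoint name is illustrative.
import transformers
from liger_kernel.transformers import apply_liger_kernel_to_qwen2

# Patch the Hugging Face Qwen2 modeling code first, so Liger's RoPE, RMSNorm,
# SwiGLU, and fused linear cross entropy are in place when the model is built.
apply_liger_kernel_to_qwen2()
model = transformers.AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-7B-Instruct")
```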
@@ -284,39 +309,33 @@ loss.backward()
 | SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
 | GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
 | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
-|
+| Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
+
+
+### Alignment Kernels
+
+| **Kernel** | **API** |
+|---------------------------------|-------------------------------------------------------------|
+| Fused Linear CPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearCPOLoss` |
+| Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
+| Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
+| Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
+
+### Distillation Kernels
+
+| **Kernel** | **API** |
+|---------------------------------|-------------------------------------------------------------|
 | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
 | JSD | `liger_kernel.transformers.LigerJSD` |
-|
-
-- **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
-- **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
-- **GroupNorm**: [GroupNorm](https://arxiv.org/pdf/1803.08494), which normalizes activations across the group dimension for a given sample. Channels are grouped in K groups over which the normalization is performed, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and can achieve up to ~2X speedup as the number of channels/groups increases.
-- **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction.
-- **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
-$$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$
-, is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction.
-- **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
-$$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
-, is implemented by fusing the elementwise multiplication into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction. Note that the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used.
-- **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with inplace replacement of input to reduce the peak memory by avoiding simultaneous materialization of both input logits and gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.).
-<!-- TODO: verify vocab sizes are accurate -->
-- **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
-- **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
-- **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size.
-- **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192.
-
+| Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
 
 ### Experimental Kernels
 
 | **Kernel** | **API** |
 |---------------------------------|-------------------------------------------------------------|
 | Embedding | `liger_kernel.transformers.experimental.LigerEmbedding` |
-| Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul`
+| Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul` |
 
-- **Embedding**: [Embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) is implemented by fusing embedding lookup and output operations. It achieves a peak speedup of ~1.5x in the forward pass and an overall speedup of ~1.1x.
-- **Matmul int2xint8**: is implemented by using the cache tiled matrix multiplication and by fusing the matmul with the unpacking process which achieves a considerable speed up and performs on par with @torch.compile
-<!-- TODO: be more specific about batch size -->
 
 ## Contributing, Acknowledgements, and License
 
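The memory figures in the kernel notes removed above are easy to sanity-check. For the `FusedLinearJSD` case (128k vocabulary, batch size × sequence length = 8192), the student's logits tensor alone holds 8192 × 131072 ≈ 1.07 × 10⁹ elements, about 4.3 GB in fp32, before counting the teacher logits and the gradients of each; chunking the model head block-wise, as the fused kernels do, avoids materializing those tensors simultaneously, which is where the quoted ~85% peak-memory reduction comes from.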
@@ -324,6 +343,17 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
 - [Acknowledgements](https://github.com/linkedin/Liger-Kernel/blob/main/docs/Acknowledgement.md)
 - [License Information](https://github.com/linkedin/Liger-Kernel/blob/main/docs/License.md)
 
+## Sponsorship and Collaboration
+
+- [AMD](https://www.amd.com/en.html): Providing AMD GPUs for our AMD CI.
+- [Intel](https://www.intel.com/): Providing Intel GPUs for our Intel CI.
+- [Modal](https://modal.com/): Free 3000 credits from GPU MODE IRL for our NVIDIA CI.
+- [EmbeddedLLM](https://embeddedllm.com/): Making Liger Kernel run fast and stable on AMD.
+- [HuggingFace](https://huggingface.co/): Integrating Liger Kernel into Hugging Face Transformers and TRL.
+- [Lightning AI](https://lightning.ai/): Integrating Liger Kernel into Lightning Thunder.
+- [Axolotl](https://axolotl.ai/): Integrating Liger Kernel into Axolotl.
+- [Llama-Factory](https://github.com/hiyouga/LLaMA-Factory): Integrating Liger Kernel into Llama-Factory.
+
 ## Contact
 
 - For issues, create a Github ticket in this repository
@@ -335,7 +365,7 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
 Biblatex entry:
 ```bib
 @article{hsu2024ligerkernelefficienttriton,
-title={Liger Kernel: Efficient Triton Kernels for LLM Training},
+title={Liger Kernel: Efficient Triton Kernels for LLM Training},
 author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen},
 year={2024},
 eprint={2410.10989},
@@ -349,15 +379,8 @@ Biblatex entry:
 ## Star History
 [](https://star-history.com/#linkedin/Liger-Kernel&Date)
 
-## Contributors
-
-<a href="https://github.com/linkedin/Liger-Kernel/graphs/contributors">
-<img alt="contributors" src="https://contrib.rocks/image?repo=linkedin/Liger-Kernel"/>
-</a>
-
 <p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
 <a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
 ↑ Back to Top ↑
 </a>
 </p>
-
{liger_kernel-0.4.2 → liger_kernel-0.5.0}/README.md
@@ -8,7 +8,7 @@
 <th style="padding: 10px;" colspan="2">Stable</th>
 <th style="padding: 10px;" colspan="2">Nightly</th>
 <th style="padding: 10px;">Discord</th>
-<th style="padding: 10px;">
+<th style="padding: 10px;">Build</th>
 </tr>
 <tr>
 <td style="padding: 10px;">
@@ -37,9 +37,16 @@
 </a>
 </td>
 <td style="padding: 10px;">
-<
-<
-
+<div style="display: block;">
+<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml">
+<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?event=schedule" alt="Build">
+</a>
+</div>
+<div style="display: block;">
+<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
+<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
+</a>
+</div>
 </td>
 </tr>
 </table>
@@ -48,13 +55,14 @@
 
 <img src="https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/logo-banner.png">
 
-[Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Cite our work](#cite-this-work)
+[Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [High-level APIs](#high-level-apis) | [Low-level APIs](#low-level-apis) | [Cite our work](#cite-this-work)
 
 <details>
 <summary>Latest News 🔥</summary>
 
+- [2024/12/15] We release LinkedIn Engineering Blog - [Liger-Kernel: Empowering an open source ecosystem of Triton Kernels for Efficient LLM Training](https://www.linkedin.com/blog/engineering/open-source/liger-kernel-open-source-ecosystem-for-efficient-llm-training)
 - [2024/11/6] We release [v0.4.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.4.0): Full AMD support, Tech Report, Modal CI, Llama-3.2-Vision!
-- [2024/10/21] We have released the tech report of Liger Kernel on Arxiv: https://arxiv.org/pdf/2410.10989
+- [2024/10/21] We have released the tech report of Liger Kernel on Arxiv: https://arxiv.org/pdf/2410.10989
 - [2024/9/6] We release v0.2.1 ([X post](https://x.com/liger_kernel/status/1832168197002510649)). 2500+ Stars, 10+ New Contributors, 50+ PRs, 50k Downloads in two weeks!
 - [2024/8/31] CUDA MODE talk, [Liger-Kernel: Real-world Triton kernel for LLM Training](https://youtu.be/gWble4FreV4?si=dxPeIchhkJ36Mbns), [Slides](https://github.com/cuda-mode/lectures?tab=readme-ov-file#lecture-28-liger-kernel)
 - [2024/8/23] Official release: check out our [X post](https://x.com/hsu_byron/status/1827072737673982056)
@@ -64,6 +72,8 @@
 
 **Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduces **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.
 
+We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more.
+
 ## Supercharge Your Model with Liger Kernel
 
 
@@ -81,12 +91,13 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
 
 ## Examples
 
-
 | **Use Case** | **Description** |
 |------------------------------------------------|---------------------------------------------------------------------------------------------------|
 | [**Hugging Face Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface) | Train LLaMA 3-8B ~20% faster with over 40% memory reduction on Alpaca dataset using 4 A100s with FSDP |
 | [**Lightning Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/lightning) | Increase 15% throughput and reduce memory usage by 40% with LLaMA3-8B on MMLU dataset using 8 A100s with DeepSpeed ZeRO3 |
-| [**Medusa Multi-head LLM (Retraining Phase)**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP
+| [**Medusa Multi-head LLM (Retraining Phase)**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP |
+| [**Vision-Language Model SFT**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface/run_qwen2_vl.sh) | Finetune Qwen2-VL on image-text data using 4 A100s with FSDP |
+| [**Liger ORPO Trainer**](https://github.com/linkedin/Liger-Kernel/blob/main/examples/alignment/run_orpo.py) | Align Llama 3.2 using Liger ORPO Trainer with FSDP with 50% memory reduction |
 
 ## Key Features
 
@@ -99,7 +110,7 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
 
 ## Installation
 
-### Dependencies
+### Dependencies
 
 #### CUDA
 
@@ -136,6 +147,8 @@ To install from source:
 git clone https://github.com/linkedin/Liger-Kernel.git
 cd Liger-Kernel
 pip install -e .
+# or if installing on amd platform
+pip install -e .[amd] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2 # rocm6.2
 # or if using transformers
 pip install -e .[transformers]
 ```
@@ -202,7 +215,7 @@ loss = loss_fn(model.weight, input, target)
 loss.backward()
 ```
 
-## APIs
+## High-level APIs
 
 ### AutoModel
 
@@ -221,13 +234,17 @@ loss.backward()
 | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
-| Qwen2
+| Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 
 
+## Low-level APIs
 
-
+- `Fused Linear` kernels combine linear layers with losses, reducing memory usage by up to 80% - ideal for HBM-constrained workloads.
+- Other kernels use fusion and in-place techniques for memory and performance optimization.
+
+### Model Kernels
 
 | **Kernel** | **API** |
 |---------------------------------|-------------------------------------------------------------|
@@ -237,39 +254,33 @@ loss.backward()
 | SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
 | GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
 | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
-|
+| Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
+
+
+### Alignment Kernels
+
+| **Kernel** | **API** |
+|---------------------------------|-------------------------------------------------------------|
+| Fused Linear CPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearCPOLoss` |
+| Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
+| Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
+| Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
+
+### Distillation Kernels
+
+| **Kernel** | **API** |
+|---------------------------------|-------------------------------------------------------------|
 | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
 | JSD | `liger_kernel.transformers.LigerJSD` |
-|
-
-- **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
-- **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
-- **GroupNorm**: [GroupNorm](https://arxiv.org/pdf/1803.08494), which normalizes activations across the group dimension for a given sample. Channels are grouped in K groups over which the normalization is performed, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and can achieve up to ~2X speedup as the number of channels/groups increases.
-- **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction.
-- **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
-$$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$
-, is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction.
-- **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
-$$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
-, is implemented by fusing the elementwise multiplication into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction. Note that the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used.
-- **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with inplace replacement of input to reduce the peak memory by avoiding simultaneous materialization of both input logits and gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.).
-<!-- TODO: verify vocab sizes are accurate -->
-- **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
-- **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
-- **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size.
-- **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192.
-
+| Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
 
 ### Experimental Kernels
 
 | **Kernel** | **API** |
 |---------------------------------|-------------------------------------------------------------|
 | Embedding | `liger_kernel.transformers.experimental.LigerEmbedding` |
-| Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul`
+| Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul` |
 
-- **Embedding**: [Embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) is implemented by fusing embedding lookup and output operations. It achieves a peak speedup of ~1.5x in the forward pass and an overall speedup of ~1.1x.
-- **Matmul int2xint8**: is implemented by using the cache tiled matrix multiplication and by fusing the matmul with the unpacking process which achieves a considerable speed up and performs on par with @torch.compile
-<!-- TODO: be more specific about batch size -->
 
 ## Contributing, Acknowledgements, and License
 
@@ -277,6 +288,17 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
 - [Acknowledgements](https://github.com/linkedin/Liger-Kernel/blob/main/docs/Acknowledgement.md)
 - [License Information](https://github.com/linkedin/Liger-Kernel/blob/main/docs/License.md)
 
+## Sponsorship and Collaboration
+
+- [AMD](https://www.amd.com/en.html): Providing AMD GPUs for our AMD CI.
+- [Intel](https://www.intel.com/): Providing Intel GPUs for our Intel CI.
+- [Modal](https://modal.com/): Free 3000 credits from GPU MODE IRL for our NVIDIA CI.
+- [EmbeddedLLM](https://embeddedllm.com/): Making Liger Kernel run fast and stable on AMD.
+- [HuggingFace](https://huggingface.co/): Integrating Liger Kernel into Hugging Face Transformers and TRL.
+- [Lightning AI](https://lightning.ai/): Integrating Liger Kernel into Lightning Thunder.
+- [Axolotl](https://axolotl.ai/): Integrating Liger Kernel into Axolotl.
+- [Llama-Factory](https://github.com/hiyouga/LLaMA-Factory): Integrating Liger Kernel into Llama-Factory.
+
 ## Contact
 
 - For issues, create a Github ticket in this repository
@@ -288,7 +310,7 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
 Biblatex entry:
 ```bib
 @article{hsu2024ligerkernelefficienttriton,
-title={Liger Kernel: Efficient Triton Kernels for LLM Training},
+title={Liger Kernel: Efficient Triton Kernels for LLM Training},
 author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen},
 year={2024},
 eprint={2410.10989},
@@ -302,15 +324,8 @@ Biblatex entry:
 ## Star History
 [](https://star-history.com/#linkedin/Liger-Kernel&Date)
 
-## Contributors
-
-<a href="https://github.com/linkedin/Liger-Kernel/graphs/contributors">
-<img alt="contributors" src="https://contrib.rocks/image?repo=linkedin/Liger-Kernel"/>
-</a>
-
 <p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
 <a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
 ↑ Back to Top ↑
 </a>
 </p>
-
{liger_kernel-0.4.2 → liger_kernel-0.5.0}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "liger_kernel"
-version = "0.4.2"
+version = "0.5.0"
 description = "Efficient Triton kernels for LLM Training"
 urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
 readme = { file = "README.md", content-type = "text/markdown" }
@@ -14,6 +14,7 @@ dependencies = [
     "triton>=2.3.1",
 ]
 
+
 [project.optional-dependencies]
 transformers = [
     "transformers~=4.0"
@@ -21,16 +22,28 @@ transformers = [
 
 dev = [
     "transformers>=4.44.2",
+    "trl>=0.11.0",
     "matplotlib>=3.7.2",
     "flake8>=4.0.1.1",
     "black>=24.4.2",
     "isort>=5.13.2",
     "pytest>=7.1.2",
+    "pytest-xdist",
+    "pytest-rerunfailures",
     "datasets>=2.19.2",
     "torchvision>=0.16.2",
     "seaborn",
 ]
 
+amd = [
+    "torch>=2.6.0.dev",
+    "setuptools-scm>=8",
+    "torchvision>=0.20.0.dev",
+    "triton>=3.0.0",
+]
+
+
+
 [tool.setuptools.packages.find]
 where = ["src"]
 include = ["liger_kernel", "liger_kernel.*"]
liger_kernel-0.5.0/src/liger_kernel/chunked_loss/__init__.py
@@ -0,0 +1,4 @@
+from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss  # noqa: F401
+from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss  # noqa: F401
+from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss  # noqa: F401
+from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss  # noqa: F401
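The four re-exports above define the import surface of the new `chunked_loss` package:

```python
# Enabled by the re-exports in chunked_loss/__init__.py shown above.
from liger_kernel.chunked_loss import (
    LigerFusedLinearCPOLoss,
    LigerFusedLinearDPOLoss,
    LigerFusedLinearORPOLoss,
    LigerFusedLinearSimPOLoss,
)
```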
liger_kernel-0.5.0/src/liger_kernel/chunked_loss/cpo_loss.py
@@ -0,0 +1,107 @@
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_preference import (
+    LigerFusedLinearPreferenceBase,
+)
+
+
+class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
+
+    @staticmethod
+    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
+        """
+        Paper: https://arxiv.org/pdf/2401.08417
+
+        Formula:
+        L(π_θ; U) = -E_(x,y_w,y_l)~D[log σ(β log π_θ(y_w|x) - β log π_θ(y_l|x))]
+
+        Where:
+        - π_θ(y|x): Policy (model) probability
+        - y_w: Chosen sequence
+        - y_l: Rejected sequence
+        - σ: Sigmoid function
+        - β: Temperature parameter
+        - E: Expected value over the dataset D
+        - D: Dataset of preferences
+
+        Args:
+            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            full_target (torch.Tensor): Non chunked full target tensor
+            beta (float): Weight for the CPO loss
+        """
+        logits = beta * (chosen_logps - rejected_logps)
+        loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
+        return loss
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input,
+        weight,
+        target,
+        bias=None,
+        ignore_index=-100,
+        beta=0.1,
+        alpha=1.0,
+        compute_nll_loss=True,
+        compiled=True,
+    ):
+        return LigerFusedLinearPreferenceBase.forward(
+            ctx,
+            _input,
+            weight,
+            target,
+            bias,
+            loss_fn=LigerFusedLinearCPOFunction.preference_loss_fn,
+            ignore_index=ignore_index,
+            alpha=alpha,
+            beta=beta,
+            compute_nll_loss=compute_nll_loss,
+            compiled=compiled,
+        )
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+        return *grads, None, None, None, None, None
+
+
+class LigerFusedLinearCPOLoss(torch.nn.Module):
+    """
+    Fused linear layer with CPO loss.
+    """
+
+    def __init__(
+        self,
+        ignore_index: int = -100,
+        beta: float = 0.1,
+        alpha: float = 1.0,
+        compute_nll_loss: bool = True,
+        compiled: bool = True,
+    ):
+        """
+        Args:
+            ignore_index (int): Index to ignore in the loss.
+            beta (float): Weight for the CPO loss.
+        """
+        super().__init__()
+        self.ignore_index = ignore_index
+        self.beta = beta
+        self.alpha = alpha
+        self.compute_nll_loss = compute_nll_loss
+        self.compiled = compiled
+
+    def forward(self, lin_weight, _input, target, bias=None):
+        return LigerFusedLinearCPOFunction.apply(
+            _input,
+            lin_weight,
+            target,
+            bias,
+            self.ignore_index,
+            self.beta,
+            self.alpha,
+            self.compute_nll_loss,
+            self.compiled,
+        )
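A hedged usage sketch for the module above. Shapes are assumptions; the calling convention follows the `forward(self, lin_weight, _input, target, bias=None)` signature in the diff, which takes the `lm_head` weight plus pre-head hidden states so the full logits tensor is never materialized:

```python
# Illustrative sketch only; shapes are assumptions, not part of the package docs.
import torch

from liger_kernel.chunked_loss import LigerFusedLinearCPOLoss

B, T, H, V = 2, 64, 512, 1000  # assumed: B preference pairs, seq len T, hidden H, vocab V
weight = torch.randn(V, H, requires_grad=True)         # lm_head weight (no bias here)
hidden = torch.randn(2 * B, T, H, requires_grad=True)  # chosen rows stacked before rejected rows
target = torch.randint(0, V, (2 * B, T))

cpo = LigerFusedLinearCPOLoss(beta=0.1, ignore_index=-100)
out = cpo(weight, hidden, target)  # CPO loss, computed chunk by chunk over the batch
```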