liger-kernel 0.4.2__tar.gz → 0.5.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. {liger_kernel-0.4.2/src/liger_kernel.egg-info → liger_kernel-0.5.1}/PKG-INFO +69 -45
  2. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/README.md +59 -44
  3. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/pyproject.toml +17 -1
  4. liger_kernel-0.5.1/src/liger_kernel/chunked_loss/__init__.py +4 -0
  5. liger_kernel-0.5.1/src/liger_kernel/chunked_loss/cpo_loss.py +107 -0
  6. liger_kernel-0.5.1/src/liger_kernel/chunked_loss/dpo_loss.py +135 -0
  7. liger_kernel-0.5.1/src/liger_kernel/chunked_loss/functional.py +9 -0
  8. liger_kernel-0.5.1/src/liger_kernel/chunked_loss/fused_linear_distillation.py +252 -0
  9. liger_kernel-0.5.1/src/liger_kernel/chunked_loss/fused_linear_preference.py +386 -0
  10. liger_kernel-0.5.1/src/liger_kernel/chunked_loss/orpo_loss.py +113 -0
  11. liger_kernel-0.5.1/src/liger_kernel/chunked_loss/simpo_loss.py +115 -0
  12. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/env_report.py +22 -0
  13. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/cross_entropy.py +17 -10
  14. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -11
  15. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/fused_linear_jsd.py +1 -1
  16. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/jsd.py +19 -10
  17. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/layer_norm.py +6 -1
  18. liger_kernel-0.5.1/src/liger_kernel/ops/qwen2vl_mrope.py +238 -0
  19. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/rms_norm.py +6 -1
  20. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/utils.py +5 -2
  21. liger_kernel-0.5.1/src/liger_kernel/transformers/functional.py +173 -0
  22. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/fused_linear_jsd.py +1 -4
  23. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/jsd.py +1 -4
  24. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/monkey_patch.py +6 -4
  25. liger_kernel-0.5.1/src/liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  26. liger_kernel-0.5.1/src/liger_kernel/transformers/trainer/__init__.py +6 -0
  27. liger_kernel-0.5.1/src/liger_kernel/transformers/trainer/orpo_trainer.py +169 -0
  28. liger_kernel-0.5.1/src/liger_kernel/utils.py +13 -0
  29. {liger_kernel-0.4.2 → liger_kernel-0.5.1/src/liger_kernel.egg-info}/PKG-INFO +69 -45
  30. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel.egg-info/SOURCES.txt +10 -0
  31. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel.egg-info/requires.txt +11 -0
  32. liger_kernel-0.4.2/src/liger_kernel/chunked_loss/dpo_loss.py +0 -57
  33. liger_kernel-0.4.2/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -206
  34. liger_kernel-0.4.2/src/liger_kernel/chunked_loss/orpo_loss.py +0 -63
  35. liger_kernel-0.4.2/src/liger_kernel/transformers/functional.py +0 -56
  36. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/LICENSE +0 -0
  37. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/NOTICE +0 -0
  38. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/setup.cfg +0 -0
  39. {liger_kernel-0.4.2/src/liger_kernel/chunked_loss → liger_kernel-0.5.1/src/liger_kernel}/__init__.py +0 -0
  40. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/__init__.py +0 -0
  41. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  42. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  43. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/geglu.py +0 -0
  44. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/group_norm.py +0 -0
  45. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/kl_div.py +0 -0
  46. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/rope.py +0 -0
  47. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/ops/swiglu.py +0 -0
  48. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/__init__.py +0 -0
  49. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/auto_model.py +0 -0
  50. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  51. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  52. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  53. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/geglu.py +0 -0
  54. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/group_norm.py +0 -0
  55. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/kl_div.py +0 -0
  56. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/layer_norm.py +0 -0
  57. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/model/__init__.py +0 -0
  58. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/model/gemma.py +0 -0
  59. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  60. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/model/llama.py +0 -0
  61. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/model/mistral.py +0 -0
  62. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  63. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/model/mllama.py +0 -0
  64. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/model/phi3.py +0 -0
  65. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  66. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  67. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/rms_norm.py +0 -0
  68. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/rope.py +0 -0
  69. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/swiglu.py +0 -0
  70. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  71. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/triton/__init__.py +0 -0
  72. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel/triton/monkey_patch.py +0 -0
  73. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel.egg-info/dependency_links.txt +0 -0
  74. {liger_kernel-0.4.2 → liger_kernel-0.5.1}/src/liger_kernel.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel
3
- Version: 0.4.2
3
+ Version: 0.5.1
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -34,6 +34,8 @@ Requires-Dist: torch>=2.1.2
34
34
  Requires-Dist: triton>=2.3.1
35
35
  Provides-Extra: transformers
36
36
  Requires-Dist: transformers~=4.0; extra == "transformers"
37
+ Provides-Extra: trl
38
+ Requires-Dist: trl>=0.11.0; extra == "trl"
37
39
  Provides-Extra: dev
38
40
  Requires-Dist: transformers>=4.44.2; extra == "dev"
39
41
  Requires-Dist: matplotlib>=3.7.2; extra == "dev"
@@ -41,9 +43,16 @@ Requires-Dist: flake8>=4.0.1.1; extra == "dev"
41
43
  Requires-Dist: black>=24.4.2; extra == "dev"
42
44
  Requires-Dist: isort>=5.13.2; extra == "dev"
43
45
  Requires-Dist: pytest>=7.1.2; extra == "dev"
46
+ Requires-Dist: pytest-xdist; extra == "dev"
47
+ Requires-Dist: pytest-rerunfailures; extra == "dev"
44
48
  Requires-Dist: datasets>=2.19.2; extra == "dev"
45
49
  Requires-Dist: torchvision>=0.16.2; extra == "dev"
46
50
  Requires-Dist: seaborn; extra == "dev"
51
+ Provides-Extra: amd
52
+ Requires-Dist: torch>=2.6.0.dev; extra == "amd"
53
+ Requires-Dist: setuptools-scm>=8; extra == "amd"
54
+ Requires-Dist: torchvision>=0.20.0.dev; extra == "amd"
55
+ Requires-Dist: triton>=3.0.0; extra == "amd"
47
56
 
48
57
  <a name="readme-top"></a>
49
58
 
@@ -55,7 +64,7 @@ Requires-Dist: seaborn; extra == "dev"
55
64
  <th style="padding: 10px;" colspan="2">Stable</th>
56
65
  <th style="padding: 10px;" colspan="2">Nightly</th>
57
66
  <th style="padding: 10px;">Discord</th>
58
- <th style="padding: 10px;">Gurubase (experimental)</th>
67
+ <th style="padding: 10px;">Build</th>
59
68
  </tr>
60
69
  <tr>
61
70
  <td style="padding: 10px;">
@@ -84,9 +93,16 @@ Requires-Dist: seaborn; extra == "dev"
84
93
  </a>
85
94
  </td>
86
95
  <td style="padding: 10px;">
87
- <a href="https://gurubase.io/g/liger-kernel">
88
- <img src="https://img.shields.io/badge/Gurubase-Ask%20Liger%20Kernel%20Guru-006BFF" alt="Ask Liger Kernel Guru">
89
- </a>
96
+ <div style="display: block;">
97
+ <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml">
98
+ <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?event=schedule" alt="Build">
99
+ </a>
100
+ </div>
101
+ <div style="display: block;">
102
+ <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
103
+ <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
104
+ </a>
105
+ </div>
90
106
  </td>
91
107
  </tr>
92
108
  </table>
@@ -95,13 +111,14 @@ Requires-Dist: seaborn; extra == "dev"
95
111
 
96
112
  <img src="https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/logo-banner.png">
97
113
 
98
- [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Cite our work](#cite-this-work)
114
+ [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [High-level APIs](#high-level-apis) | [Low-level APIs](#low-level-apis) | [Cite our work](#cite-this-work)
99
115
 
100
116
  <details>
101
117
  <summary>Latest News 🔥</summary>
102
118
 
119
+ - [2024/12/15] We release LinkedIn Engineering Blog - [Liger-Kernel: Empowering an open source ecosystem of Triton Kernels for Efficient LLM Training](https://www.linkedin.com/blog/engineering/open-source/liger-kernel-open-source-ecosystem-for-efficient-llm-training)
103
120
  - [2024/11/6] We release [v0.4.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.4.0): Full AMD support, Tech Report, Modal CI, Llama-3.2-Vision!
104
- - [2024/10/21] We have released the tech report of Liger Kernel on Arxiv: https://arxiv.org/pdf/2410.10989
121
+ - [2024/10/21] We have released the tech report of Liger Kernel on Arxiv: https://arxiv.org/pdf/2410.10989
105
122
  - [2024/9/6] We release v0.2.1 ([X post](https://x.com/liger_kernel/status/1832168197002510649)). 2500+ Stars, 10+ New Contributors, 50+ PRs, 50k Downloads in two weeks!
106
123
  - [2024/8/31] CUDA MODE talk, [Liger-Kernel: Real-world Triton kernel for LLM Training](https://youtu.be/gWble4FreV4?si=dxPeIchhkJ36Mbns), [Slides](https://github.com/cuda-mode/lectures?tab=readme-ov-file#lecture-28-liger-kernel)
107
124
  - [2024/8/23] Official release: check out our [X post](https://x.com/hsu_byron/status/1827072737673982056)
@@ -111,6 +128,8 @@ Requires-Dist: seaborn; extra == "dev"
111
128
 
112
129
  **Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduces **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.
113
130
 
131
+ We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more.
132
+
114
133
  ## Supercharge Your Model with Liger Kernel
115
134
 
116
135
  ![Banner](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/banner.GIF)
@@ -128,12 +147,13 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
128
147
 
129
148
  ## Examples
130
149
 
131
-
132
150
  | **Use Case** | **Description** |
133
151
  |------------------------------------------------|---------------------------------------------------------------------------------------------------|
134
152
  | [**Hugging Face Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface) | Train LLaMA 3-8B ~20% faster with over 40% memory reduction on Alpaca dataset using 4 A100s with FSDP |
135
153
  | [**Lightning Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/lightning) | Increase 15% throughput and reduce memory usage by 40% with LLaMA3-8B on MMLU dataset using 8 A100s with DeepSpeed ZeRO3 |
136
- | [**Medusa Multi-head LLM (Retraining Phase)**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP | |
154
+ | [**Medusa Multi-head LLM (Retraining Phase)**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP |
155
+ | [**Vision-Language Model SFT**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface/run_qwen2_vl.sh) | Finetune Qwen2-VL on image-text data using 4 A100s with FSDP |
156
+ | [**Liger ORPO Trainer**](https://github.com/linkedin/Liger-Kernel/blob/main/examples/alignment/run_orpo.py) | Align Llama 3.2 using Liger ORPO Trainer with FSDP with 50% memory reduction |
137
157
 
138
158
  ## Key Features
139
159
 
@@ -146,7 +166,7 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
146
166
 
147
167
  ## Installation
148
168
 
149
- ### Dependencies
169
+ ### Dependencies
150
170
 
151
171
  #### CUDA
152
172
 
@@ -183,6 +203,8 @@ To install from source:
183
203
  git clone https://github.com/linkedin/Liger-Kernel.git
184
204
  cd Liger-Kernel
185
205
  pip install -e .
206
+ # or if installing on amd platform
207
+ pip install -e .[amd] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2 # rocm6.2
186
208
  # or if using transformers
187
209
  pip install -e .[transformers]
188
210
  ```
@@ -249,7 +271,7 @@ loss = loss_fn(model.weight, input, target)
249
271
  loss.backward()
250
272
  ```
251
273
 
252
- ## APIs
274
+ ## High-level APIs
253
275
 
254
276
  ### AutoModel
255
277
 
@@ -268,13 +290,17 @@ loss.backward()
268
290
  | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
269
291
  | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
270
292
  | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
271
- | Qwen2 & Qwen2.5 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
293
+ | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
272
294
  | Qwen2-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
273
295
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
274
296
 
275
297
 
298
+ ## Low-level APIs
276
299
 
277
- ### Kernels
300
+ - `Fused Linear` kernels combine linear layers with losses, reducing memory usage by up to 80% - ideal for HBM-constrained workloads.
301
+ - Other kernels use fusion and in-place techniques for memory and performance optimization.
302
+
303
+ ### Model Kernels
278
304
 
279
305
  | **Kernel** | **API** |
280
306
  |---------------------------------|-------------------------------------------------------------|
@@ -284,39 +310,33 @@ loss.backward()
284
310
  | SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
285
311
  | GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
286
312
  | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
287
- | FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
313
+ | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
314
+
315
+
316
+ ### Alignment Kernels
317
+
318
+ | **Kernel** | **API** |
319
+ |---------------------------------|-------------------------------------------------------------|
320
+ | Fused Linear CPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearCPOLoss` |
321
+ | Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
322
+ | Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
323
+ | Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
324
+
325
+ ### Distillation Kernels
326
+
327
+ | **Kernel** | **API** |
328
+ |---------------------------------|-------------------------------------------------------------|
288
329
  | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
289
330
  | JSD | `liger_kernel.transformers.LigerJSD` |
290
- | FusedLinearJSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
291
-
292
- - **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
293
- - **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
294
- - **GroupNorm**: [GroupNorm](https://arxiv.org/pdf/1803.08494), which normalizes activations across the group dimension for a given sample. Channels are grouped in K groups over which the normalization is performed, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and can achieve up to ~2X speedup as the number of channels/groups increases.
295
- - **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction.
296
- - **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
297
- $$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$
298
- , is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction.
299
- - **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
300
- $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
301
- , is implemented by fusing the elementwise multiplication into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction. Note that the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used.
302
- - **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with inplace replacement of input to reduce the peak memory by avoiding simultaneous materialization of both input logits and gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.).
303
- <!-- TODO: verify vocab sizes are accurate -->
304
- - **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
305
- - **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
306
- - **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size.
307
- - **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192.
308
-
331
+ | Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
309
332
 
310
333
  ### Experimental Kernels
311
334
 
312
335
  | **Kernel** | **API** |
313
336
  |---------------------------------|-------------------------------------------------------------|
314
337
  | Embedding | `liger_kernel.transformers.experimental.LigerEmbedding` |
315
- | Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul`
338
+ | Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul` |
316
339
 
317
- - **Embedding**: [Embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) is implemented by fusing embedding lookup and output operations. It achieves a peak speedup of ~1.5x in the forward pass and an overall speedup of ~1.1x.
318
- - **Matmul int2xint8**: is implemented by using the cache tiled matrix multiplication and by fusing the matmul with the unpacking process which achieves a considerable speed up and performs on par with @torch.compile
319
- <!-- TODO: be more specific about batch size -->
320
340
 
321
341
  ## Contributing, Acknowledgements, and License
322
342
 
@@ -324,6 +344,17 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
324
344
  - [Acknowledgements](https://github.com/linkedin/Liger-Kernel/blob/main/docs/Acknowledgement.md)
325
345
  - [License Information](https://github.com/linkedin/Liger-Kernel/blob/main/docs/License.md)
326
346
 
347
+ ## Sponsorship and Collaboration
348
+
349
+ - [AMD](https://www.amd.com/en.html): Providing AMD GPUs for our AMD CI.
350
+ - [Intel](https://www.intel.com/): Providing Intel GPUs for our Intel CI.
351
+ - [Modal](https://modal.com/): Free 3000 credits from GPU MODE IRL for our NVIDIA CI.
352
+ - [EmbeddedLLM](https://embeddedllm.com/): Making Liger Kernel run fast and stable on AMD.
353
+ - [HuggingFace](https://huggingface.co/): Integrating Liger Kernel into Hugging Face Transformers and TRL.
354
+ - [Lightning AI](https://lightning.ai/): Integrating Liger Kernel into Lightning Thunder.
355
+ - [Axolotl](https://axolotl.ai/): Integrating Liger Kernel into Axolotl.
356
+ - [Llama-Factory](https://github.com/hiyouga/LLaMA-Factory): Integrating Liger Kernel into Llama-Factory.
357
+
327
358
  ## Contact
328
359
 
329
360
  - For issues, create a Github ticket in this repository
@@ -335,7 +366,7 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
335
366
  Biblatex entry:
336
367
  ```bib
337
368
  @article{hsu2024ligerkernelefficienttriton,
338
- title={Liger Kernel: Efficient Triton Kernels for LLM Training},
369
+ title={Liger Kernel: Efficient Triton Kernels for LLM Training},
339
370
  author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen},
340
371
  year={2024},
341
372
  eprint={2410.10989},
@@ -349,15 +380,8 @@ Biblatex entry:
349
380
  ## Star History
350
381
  [![Star History Chart](https://api.star-history.com/svg?repos=linkedin/Liger-Kernel&type=Date)](https://star-history.com/#linkedin/Liger-Kernel&Date)
351
382
 
352
- ## Contributors
353
-
354
- <a href="https://github.com/linkedin/Liger-Kernel/graphs/contributors">
355
- <img alt="contributors" src="https://contrib.rocks/image?repo=linkedin/Liger-Kernel"/>
356
- </a>
357
-
358
383
  <p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
359
384
  <a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
360
385
  ↑ Back to Top ↑
361
386
  </a>
362
387
  </p>
363
-
@@ -8,7 +8,7 @@
8
8
  <th style="padding: 10px;" colspan="2">Stable</th>
9
9
  <th style="padding: 10px;" colspan="2">Nightly</th>
10
10
  <th style="padding: 10px;">Discord</th>
11
- <th style="padding: 10px;">Gurubase (experimental)</th>
11
+ <th style="padding: 10px;">Build</th>
12
12
  </tr>
13
13
  <tr>
14
14
  <td style="padding: 10px;">
@@ -37,9 +37,16 @@
37
37
  </a>
38
38
  </td>
39
39
  <td style="padding: 10px;">
40
- <a href="https://gurubase.io/g/liger-kernel">
41
- <img src="https://img.shields.io/badge/Gurubase-Ask%20Liger%20Kernel%20Guru-006BFF" alt="Ask Liger Kernel Guru">
42
- </a>
40
+ <div style="display: block;">
41
+ <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml">
42
+ <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?event=schedule" alt="Build">
43
+ </a>
44
+ </div>
45
+ <div style="display: block;">
46
+ <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
47
+ <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
48
+ </a>
49
+ </div>
43
50
  </td>
44
51
  </tr>
45
52
  </table>
@@ -48,13 +55,14 @@
48
55
 
49
56
  <img src="https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/logo-banner.png">
50
57
 
51
- [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Cite our work](#cite-this-work)
58
+ [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [High-level APIs](#high-level-apis) | [Low-level APIs](#low-level-apis) | [Cite our work](#cite-this-work)
52
59
 
53
60
  <details>
54
61
  <summary>Latest News 🔥</summary>
55
62
 
63
+ - [2024/12/15] We release LinkedIn Engineering Blog - [Liger-Kernel: Empowering an open source ecosystem of Triton Kernels for Efficient LLM Training](https://www.linkedin.com/blog/engineering/open-source/liger-kernel-open-source-ecosystem-for-efficient-llm-training)
56
64
  - [2024/11/6] We release [v0.4.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.4.0): Full AMD support, Tech Report, Modal CI, Llama-3.2-Vision!
57
- - [2024/10/21] We have released the tech report of Liger Kernel on Arxiv: https://arxiv.org/pdf/2410.10989
65
+ - [2024/10/21] We have released the tech report of Liger Kernel on Arxiv: https://arxiv.org/pdf/2410.10989
58
66
  - [2024/9/6] We release v0.2.1 ([X post](https://x.com/liger_kernel/status/1832168197002510649)). 2500+ Stars, 10+ New Contributors, 50+ PRs, 50k Downloads in two weeks!
59
67
  - [2024/8/31] CUDA MODE talk, [Liger-Kernel: Real-world Triton kernel for LLM Training](https://youtu.be/gWble4FreV4?si=dxPeIchhkJ36Mbns), [Slides](https://github.com/cuda-mode/lectures?tab=readme-ov-file#lecture-28-liger-kernel)
60
68
  - [2024/8/23] Official release: check out our [X post](https://x.com/hsu_byron/status/1827072737673982056)
@@ -64,6 +72,8 @@
64
72
 
65
73
  **Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduces **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.
66
74
 
75
+ We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more.
76
+
67
77
  ## Supercharge Your Model with Liger Kernel
68
78
 
69
79
  ![Banner](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/banner.GIF)
@@ -81,12 +91,13 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
81
91
 
82
92
  ## Examples
83
93
 
84
-
85
94
  | **Use Case** | **Description** |
86
95
  |------------------------------------------------|---------------------------------------------------------------------------------------------------|
87
96
  | [**Hugging Face Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface) | Train LLaMA 3-8B ~20% faster with over 40% memory reduction on Alpaca dataset using 4 A100s with FSDP |
88
97
  | [**Lightning Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/lightning) | Increase 15% throughput and reduce memory usage by 40% with LLaMA3-8B on MMLU dataset using 8 A100s with DeepSpeed ZeRO3 |
89
- | [**Medusa Multi-head LLM (Retraining Phase)**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP | |
98
+ | [**Medusa Multi-head LLM (Retraining Phase)**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP |
99
+ | [**Vision-Language Model SFT**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface/run_qwen2_vl.sh) | Finetune Qwen2-VL on image-text data using 4 A100s with FSDP |
100
+ | [**Liger ORPO Trainer**](https://github.com/linkedin/Liger-Kernel/blob/main/examples/alignment/run_orpo.py) | Align Llama 3.2 using Liger ORPO Trainer with FSDP with 50% memory reduction |
90
101
 
91
102
  ## Key Features
92
103
 
@@ -99,7 +110,7 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
99
110
 
100
111
  ## Installation
101
112
 
102
- ### Dependencies
113
+ ### Dependencies
103
114
 
104
115
  #### CUDA
105
116
 
@@ -136,6 +147,8 @@ To install from source:
136
147
  git clone https://github.com/linkedin/Liger-Kernel.git
137
148
  cd Liger-Kernel
138
149
  pip install -e .
150
+ # or if installing on amd platform
151
+ pip install -e .[amd] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2 # rocm6.2
139
152
  # or if using transformers
140
153
  pip install -e .[transformers]
141
154
  ```
@@ -202,7 +215,7 @@ loss = loss_fn(model.weight, input, target)
202
215
  loss.backward()
203
216
  ```
204
217
 
205
- ## APIs
218
+ ## High-level APIs
206
219
 
207
220
  ### AutoModel
208
221
 
@@ -221,13 +234,17 @@ loss.backward()
221
234
  | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
222
235
  | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
223
236
  | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
224
- | Qwen2 & Qwen2.5 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
237
+ | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
225
238
  | Qwen2-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
226
239
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
227
240
 
228
241
 
242
+ ## Low-level APIs
229
243
 
230
- ### Kernels
244
+ - `Fused Linear` kernels combine linear layers with losses, reducing memory usage by up to 80% - ideal for HBM-constrained workloads.
245
+ - Other kernels use fusion and in-place techniques for memory and performance optimization.
246
+
247
+ ### Model Kernels
231
248
 
232
249
  | **Kernel** | **API** |
233
250
  |---------------------------------|-------------------------------------------------------------|
@@ -237,39 +254,33 @@ loss.backward()
237
254
  | SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
238
255
  | GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
239
256
  | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
240
- | FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
257
+ | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
258
+
259
+
260
+ ### Alignment Kernels
261
+
262
+ | **Kernel** | **API** |
263
+ |---------------------------------|-------------------------------------------------------------|
264
+ | Fused Linear CPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearCPOLoss` |
265
+ | Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
266
+ | Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
267
+ | Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
268
+
269
+ ### Distillation Kernels
270
+
271
+ | **Kernel** | **API** |
272
+ |---------------------------------|-------------------------------------------------------------|
241
273
  | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
242
274
  | JSD | `liger_kernel.transformers.LigerJSD` |
243
- | FusedLinearJSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
244
-
245
- - **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
246
- - **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
247
- - **GroupNorm**: [GroupNorm](https://arxiv.org/pdf/1803.08494), which normalizes activations across the group dimension for a given sample. Channels are grouped in K groups over which the normalization is performed, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and can achieve up to ~2X speedup as the number of channels/groups increases.
248
- - **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction.
249
- - **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
250
- $$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$
251
- , is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction.
252
- - **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
253
- $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
254
- , is implemented by fusing the elementwise multiplication into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction. Note that the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used.
255
- - **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with inplace replacement of input to reduce the peak memory by avoiding simultaneous materialization of both input logits and gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.).
256
- <!-- TODO: verify vocab sizes are accurate -->
257
- - **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
258
- - **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
259
- - **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size.
260
- - **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192.
261
-
275
+ | Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
262
276
 
263
277
  ### Experimental Kernels
264
278
 
265
279
  | **Kernel** | **API** |
266
280
  |---------------------------------|-------------------------------------------------------------|
267
281
  | Embedding | `liger_kernel.transformers.experimental.LigerEmbedding` |
268
- | Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul`
282
+ | Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul` |
269
283
 
270
- - **Embedding**: [Embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) is implemented by fusing embedding lookup and output operations. It achieves a peak speedup of ~1.5x in the forward pass and an overall speedup of ~1.1x.
271
- - **Matmul int2xint8**: is implemented by using the cache tiled matrix multiplication and by fusing the matmul with the unpacking process which achieves a considerable speed up and performs on par with @torch.compile
272
- <!-- TODO: be more specific about batch size -->
273
284
 
274
285
  ## Contributing, Acknowledgements, and License
275
286
 
@@ -277,6 +288,17 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
277
288
  - [Acknowledgements](https://github.com/linkedin/Liger-Kernel/blob/main/docs/Acknowledgement.md)
278
289
  - [License Information](https://github.com/linkedin/Liger-Kernel/blob/main/docs/License.md)
279
290
 
291
+ ## Sponsorship and Collaboration
292
+
293
+ - [AMD](https://www.amd.com/en.html): Providing AMD GPUs for our AMD CI.
294
+ - [Intel](https://www.intel.com/): Providing Intel GPUs for our Intel CI.
295
+ - [Modal](https://modal.com/): Free 3000 credits from GPU MODE IRL for our NVIDIA CI.
296
+ - [EmbeddedLLM](https://embeddedllm.com/): Making Liger Kernel run fast and stably on AMD.
297
+ - [HuggingFace](https://huggingface.co/): Integrating Liger Kernel into Hugging Face Transformers and TRL.
298
+ - [Lightning AI](https://lightning.ai/): Integrating Liger Kernel into Lightning Thunder.
299
+ - [Axolotl](https://axolotl.ai/): Integrating Liger Kernel into Axolotl.
300
+ - [Llama-Factory](https://github.com/hiyouga/LLaMA-Factory): Integrating Liger Kernel into Llama-Factory.
301
+
280
302
  ## Contact
281
303
 
282
304
  - For issues, create a Github ticket in this repository
@@ -288,7 +310,7 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
288
310
  Biblatex entry:
289
311
  ```bib
290
312
  @article{hsu2024ligerkernelefficienttriton,
291
- title={Liger Kernel: Efficient Triton Kernels for LLM Training},
313
+ title={Liger Kernel: Efficient Triton Kernels for LLM Training},
292
314
  author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen},
293
315
  year={2024},
294
316
  eprint={2410.10989},
@@ -302,15 +324,8 @@ Biblatex entry:
302
324
  ## Star History
303
325
  [![Star History Chart](https://api.star-history.com/svg?repos=linkedin/Liger-Kernel&type=Date)](https://star-history.com/#linkedin/Liger-Kernel&Date)
304
326
 
305
- ## Contributors
306
-
307
- <a href="https://github.com/linkedin/Liger-Kernel/graphs/contributors">
308
- <img alt="contributors" src="https://contrib.rocks/image?repo=linkedin/Liger-Kernel"/>
309
- </a>
310
-
311
327
  <p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
312
328
  <a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
313
329
  ↑ Back to Top ↑
314
330
  </a>
315
331
  </p>
316
-
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel"
7
- version = "0.4.2"
7
+ version = "0.5.1"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -14,11 +14,16 @@ dependencies = [
14
14
  "triton>=2.3.1",
15
15
  ]
16
16
 
17
+
17
18
  [project.optional-dependencies]
18
19
  transformers = [
19
20
  "transformers~=4.0"
20
21
  ]
21
22
 
23
+ trl = [
24
+ "trl>=0.11.0",
25
+ ]
26
+
22
27
  dev = [
23
28
  "transformers>=4.44.2",
24
29
  "matplotlib>=3.7.2",
@@ -26,11 +31,22 @@ dev = [
26
31
  "black>=24.4.2",
27
32
  "isort>=5.13.2",
28
33
  "pytest>=7.1.2",
34
+ "pytest-xdist",
35
+ "pytest-rerunfailures",
29
36
  "datasets>=2.19.2",
30
37
  "torchvision>=0.16.2",
31
38
  "seaborn",
32
39
  ]
33
40
 
41
+ amd = [
42
+ "torch>=2.6.0.dev",
43
+ "setuptools-scm>=8",
44
+ "torchvision>=0.20.0.dev",
45
+ "triton>=3.0.0",
46
+ ]
47
+
48
+
49
+
34
50
  [tool.setuptools.packages.find]
35
51
  where = ["src"]
36
52
  include = ["liger_kernel", "liger_kernel.*"]
@@ -0,0 +1,4 @@
1
+ from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
2
+ from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
3
+ from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
4
+ from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401
@@ -0,0 +1,107 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from liger_kernel.chunked_loss.fused_linear_preference import (
5
+ LigerFusedLinearPreferenceBase,
6
+ )
7
+
8
+
9
+ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
10
+
11
+ @staticmethod
12
+ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
13
+ """
14
+ Paper: https://arxiv.org/pdf/2401.08417
15
+
16
+ Formula:
17
+ L(π_θ; U) = -E_(x,y_w,y_l)~D[log σ(β log π_θ(y_w|x) - β log π_θ(y_l|x))]
18
+
19
+ Where:
20
+ - π_θ(y|x): Policy (model) probability
21
+ - y_w: Chosen sequence
22
+ - y_l: Rejected sequence
23
+ - σ: Sigmoid function
24
+ - β: Temperature parameter
25
+ - E: Expected value over the dataset D
26
+ - D: Dataset of preferences
27
+
28
+ Args:
29
+ chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
30
+ rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
31
+ full_target (torch.Tensor): Non chunked full target tensor
32
+ beta (float): Weight for the CPO loss
33
+ """
34
+ logits = beta * (chosen_logps - rejected_logps)
35
+ loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
36
+ return loss
37
+
38
+ @staticmethod
39
+ def forward(
40
+ ctx,
41
+ _input,
42
+ weight,
43
+ target,
44
+ bias=None,
45
+ ignore_index=-100,
46
+ beta=0.1,
47
+ alpha=1.0,
48
+ compute_nll_loss=True,
49
+ compiled=True,
50
+ ):
51
+ return LigerFusedLinearPreferenceBase.forward(
52
+ ctx,
53
+ _input,
54
+ weight,
55
+ target,
56
+ bias,
57
+ loss_fn=LigerFusedLinearCPOFunction.preference_loss_fn,
58
+ ignore_index=ignore_index,
59
+ alpha=alpha,
60
+ beta=beta,
61
+ compute_nll_loss=compute_nll_loss,
62
+ compiled=compiled,
63
+ )
64
+
65
+ @staticmethod
66
+ def backward(ctx, *grad_output):
67
+ grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
68
+ return *grads, None, None, None, None, None
69
+
70
+
71
class LigerFusedLinearCPOLoss(torch.nn.Module):
    """
    Fused linear layer with CPO (Contrastive Preference Optimization) loss.

    Paper: https://arxiv.org/pdf/2401.08417
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        alpha: float = 1.0,
        compute_nll_loss: bool = True,
        compiled: bool = True,
    ):
        """
        Args:
            ignore_index (int): Target index to ignore in the loss.
            beta (float): Temperature scaling the log-probability margin in the
                CPO preference term (not an odds-ratio weight).
            alpha (float): Weight of the auxiliary NLL term on chosen sequences.
            compute_nll_loss (bool): Whether to add the NLL term on chosen sequences.
            compiled (bool): Whether to run the chunked loss computation compiled.
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.beta = beta
        self.alpha = alpha
        self.compute_nll_loss = compute_nll_loss
        self.compiled = compiled

    def forward(self, lin_weight, _input, target, bias=None):
        """Compute the fused linear + CPO loss.

        Args:
            lin_weight (torch.Tensor): LM-head weight.
            _input (torch.Tensor): Hidden states; presumably chosen and rejected
                halves concatenated along the batch dimension — see
                preference_loss_fn's normalization.
            target (torch.Tensor): Token targets aligned with `_input`.
            bias (torch.Tensor, optional): LM-head bias.
        """
        return LigerFusedLinearCPOFunction.apply(
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.beta,
            self.alpha,
            self.compute_nll_loss,
            self.compiled,
        )