liger-kernel 0.0.0__tar.gz → 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel-0.1.0/LICENSE +23 -0
- liger_kernel-0.1.0/NOTICE +4 -0
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/PKG-INFO +3 -1
- liger_kernel-0.1.0/README.md +205 -0
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/setup.py +1 -1
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/ops/cross_entropy.py +4 -33
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/ops/fused_linear_cross_entropy.py +6 -6
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/ops/geglu.py +14 -3
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/ops/rms_norm.py +40 -22
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/ops/swiglu.py +16 -16
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/ops/utils.py +12 -0
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/transformers/__init__.py +1 -0
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/transformers/model/llama.py +3 -0
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/transformers/monkey_patch.py +35 -8
- liger_kernel-0.1.0/src/liger_kernel/transformers/trainer_integration.py +45 -0
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/triton/monkey_patch.py +0 -2
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel.egg-info/PKG-INFO +3 -1
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel.egg-info/SOURCES.txt +4 -0
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/setup.cfg +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel.egg-info/dependency_links.txt +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel.egg-info/requires.txt +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel.egg-info/top_level.txt +0 -0
liger_kernel-0.1.0/LICENSE

```diff
@@ -0,0 +1,23 @@
+BSD 2-CLAUSE LICENSE
+Copyright 2024 LinkedIn Corporation
+All Rights Reserved.
+Redistribution and use in source and binary forms, with or
+without modification, are permitted provided that the following
+conditions are met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following
+disclaimer in the documentation and/or other materials provided
+with the distribution.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```
liger_kernel-0.1.0/README.md

````diff
@@ -0,0 +1,205 @@
+# Liger Kernel: Efficient Triton Kernels for LLM Training
+
+[![Downloads](https://static.pepy.tech/badge/liger-kernel)](https://pepy.tech/project/liger-kernel) [![PyPI version](https://badge.fury.io/py/liger-kernel.svg)](https://badge.fury.io/py/liger-kernel) [![PyPI version](https://badge.fury.io/py/liger-kernel-nightly.svg)](https://badge.fury.io/py/liger-kernel-nightly)
+
+
+[Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Structure](#structure) | [Contributing](#contributing)
+
+**Liger (LinkedIn GPU Efficient Runtime) Kernel** is a collection of Triton kernels designed specifically for LLM training. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. It can effectively increase multi-GPU **training throughput by 20%** and reduce **memory usage by 60%**. The kernels work out of the box with [flash attention](https://github.com/Dao-AILab/flash-attention), PyTorch FSDP, and Microsoft DeepSpeed. We welcome contributions from the community to gather the best kernels for LLM training.
+
+
+
+## Supercharge Your Model with Liger Kernel
+
+
+Gain +20% throughput and reduce memory usage by 60%. Achieve longer context lengths and larger batch sizes. It’s also useful if you want to scale up your model to multi-head training or large vocabulary sizes.
+
+| Speed Up | Memory Reduction |
+|--------------------------|-------------------------|
+|  |  |
+
+
+> - Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = bf16, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
+> - Hugging Face models start to OOM at a 4K context length, whereas Liger Kernel scales up to 16K.
+
+## Examples
+
+### Basic
+
+| **Example** | **Description** | **Lightning Studio** |
+|------------------------------------------------|---------------------------------------------------------------------------------------------------|----------------------|
+| [**Hugging Face Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface) | Train LLaMA 3-8B ~20% faster with over 40% memory reduction on the Alpaca dataset using 4 A100s with FSDP | TBA |
+| [**Lightning Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/lightning) | Increase throughput by 15% and reduce memory usage by 40% with LLaMA 3-8B on the MMLU dataset using 8 A100s with DeepSpeed ZeRO3 | TBA |
+
+### Advanced
+
+| **Example** | **Description** | **Lightning Studio** |
+|------------------------------------------------|---------------------------------------------------------------------------------------------------|----------------------|
+| [**Medusa Multi-head LLM (Retraining Phase)**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP | TBA |
+
+## Key Features
+
+- **Ease of use:** Simply patch your Hugging Face model with one line of code, or compose your own model using our kernels.
+- **Time- and memory-efficient:** In the same spirit as Flash-Attn, but for layers like **RMSNorm**, **RoPE**, and **CrossEntropy**! Increases multi-GPU training throughput by 20% and reduces memory usage by 60% with **kernel fusion**, **in-place replacement**, and **chunking** techniques.
+- **Exact:** Exact kernels, no approximations. Both forward and backward passes are implemented with rigorous unit and convergence testing to ensure accuracy.
+- **Lightweight:** The kernels have minimal dependencies, requiring only Torch and Triton, with no extra libraries needed! Say goodbye to dependency headaches!
+- **Multi-GPU supported:** Compatible with multi-GPU setups (PyTorch FSDP and DeepSpeed).
+
+## Target Audiences
+
+- **Researchers**: Looking to compose models using efficient and reliable kernels for frontier experiments.
+- **ML Practitioners**: Focused on maximizing GPU training efficiency with optimal, high-performance kernels.
+- **Curious Novices**: Eager to learn how to write reliable Triton kernels to enhance training efficiency.
+
+
+## Installation
+
+### Dependencies
+
+- `torch >= 2.1.2`
+- `triton >= 2.3.0`
+- `transformers >= 4.40.1`
+
+To install the stable version:
+
+```bash
+$ pip install liger-kernel
+```
+
+To install the nightly version:
+
+```bash
+$ pip install liger-kernel-nightly
+```
+
+## Getting Started
+
+### 1. Patch Existing Hugging Face Models
+
+Using the [patching APIs](#patching), you can swap the default modules in Hugging Face models for optimized Liger kernels.
+
+```python
+import transformers
+from liger_kernel.transformers import apply_liger_kernel_to_llama
+
+model = transformers.AutoModelForCausalLM.from_pretrained("<some llama model>")
+
+# Adding this line automatically monkey-patches the model with the optimized kernels
+apply_liger_kernel_to_llama()
+```
+
+### 2. Compose Your Own Model
+
+You can take individual [kernels](#kernels) to compose your models.
+
+```python
+from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss
+import torch.nn as nn
+import torch
+
+model = nn.Linear(128, 256).to("cuda")
+
+# LigerFusedLinearCrossEntropyLoss fuses linear and cross entropy layers together and performs chunk-by-chunk computation to reduce memory
+loss_fn = LigerFusedLinearCrossEntropyLoss()
+
+input = torch.randn(4, 128, requires_grad=True, device="cuda")
+target = torch.empty(4, dtype=torch.long, device="cuda").random_(256)
+
+loss = loss_fn(model.weight, input, target)
+loss.backward()
+```
+
+
+## Structure
+
+### Source Code
+
+- `ops/`: Core Triton operations.
+- `transformers/`: PyTorch `nn.Module` implementations built on Triton operations, compliant with the `transformers` API.
+
+### Tests
+
+- `transformers/`: Correctness tests for the Triton-based layers.
+- `convergence/`: Patches Hugging Face models with all kernels, runs multiple iterations, and compares weights, logits, and loss layer by layer.
+
+### Benchmark
+
+- `benchmark/`: Execution time and memory benchmarks compared to Hugging Face layers.
+
+## APIs
+
+### Patching
+
+| **Model** | **API** | **Supported Operations** |
+|-------------|--------------------------------------------------------------|-------------------------------------------------------------------------|
+| LLaMA (2 & 3) | `liger_kernel.transformers.apply_liger_kernel_to_llama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Mistral | `liger_kernel.transformers.apply_liger_kernel_to_mistral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
+| Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
+| Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss |
+
+### Kernels
+
+| **Kernel** | **API** |
+|---------------------------------|--------------------------------------------------------------|
+| RMSNorm | `liger_kernel.transformers.LigerRMSNorm` |
+| RoPE | `liger_kernel.transformers.liger_rotary_pos_emb` |
+| SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
+| GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
+| CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
+| FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss` |
+
+- **RMSNorm**: RMSNorm, which normalizes tensor activations using their root mean square, is accelerated by fusing the normalization and scaling steps into a single Triton kernel, achieving a ~3X speedup with ~3X peak memory reduction. [RMSNorm Paper](https://arxiv.org/pdf/1910.07467)
+- **RoPE**: Fuses the query and key rotary embedding operations into a single kernel with in-place replacement, achieving a ~3X speedup with ~3X peak memory reduction. [RoPE Paper](https://arxiv.org/pdf/2104.09864)
+- **SwiGLU**: Uses a fused Triton kernel for the elementwise transformation in SwiGLU ($$\text{Swish}_{\beta=1}(A) \odot B$$) with in-place replacement, achieving parity speed with ~1.5X peak memory reduction. [SwiGLU Paper](https://arxiv.org/pdf/2002.05202)
+- **GeGLU**: Uses a fused Triton kernel for the elementwise transformation in GeGLU with the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) and in-place replacement, achieving parity speed with ~1.5X peak memory reduction. [GeGLU paper](https://arxiv.org/pdf/2002.05202)
+- **CrossEntropy**: Computes both the loss and the gradient in the forward pass, with in-place replacement of the input to reduce peak memory (avoiding materialization of both the input logits and the gradient), achieving a >2X speedup and >4X memory reduction for common vocab sizes. [PyTorch CrossEntropyLoss Documentation](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html)
+- **FusedLinearCrossEntropy**: Further reduces peak memory beyond the basic Liger cross entropy kernel by fusing the model's last output head with the CE loss and chunking the input for block-wise loss and gradient computation, inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy), achieving a >4X memory reduction for a 128k vocab size. **This is highly effective for models with large batch sizes, long sequence lengths, and large vocabularies.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
+
+
+> * Reported speedups and memory reductions are measured against LLaMA 3-8B Hugging Face layer implementations with the default 4k hidden size and 4k sequence length, for a single forward and backward pass on a single NVIDIA A100 80G GPU with small batch sizes. Liger kernels exhibit more efficient scaling to larger batch sizes of tokens. See the [Benchmark](./benchmark) folder for details.
+
+## Note on ML Compiler
+
+### 1. Torch Compile
+
+Since Liger Kernel is 100% Triton-based, it works seamlessly with Torch Compile. In the following example, Liger Kernel can further optimize the model on top of Torch Compile, reducing the memory by more than half.
+
+| Configuration | Throughput (tokens/sec) | Memory Reserved (MB) |
+|--------------------------------|----------------------------|-------------------------|
+| Torch Compile | 3780 | 66358 |
+| Torch Compile + Liger Kernel | 3702 | 31000 |
+
+> **Note:**
+> 1. Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Seq Len = 4096, Data Type = bf16, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
+> 2. Tested on torch `2.5.0.dev20240731+cu118`
+
+### 2. Lightning Thunder
+
+*WIP*
+
+## Contributing
+
+[CONTRIBUTING GUIDE](https://github.com/linkedin/Liger-Kernel/blob/main/CONTRIBUTING.md)
+
+## Acknowledgement
+
+- [flash-attn](https://github.com/Dao-AILab/flash-attention) and [Unsloth](https://github.com/unslothai/unsloth) for inspiration on Triton kernels for training
+- Andrej Karpathy's [tiny shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt), used for convergence testing
+
+
+## License
+
+[BSD 2-CLAUSE](https://github.com/linkedin/Liger-Kernel/blob/main/LICENSE)
+
+## Cite this work
+
+BibLaTeX entry:
+```bib
+@software{liger2024,
+  title  = {Liger-Kernel: Efficient Triton Kernels for LLM Training},
+  author = {Hsu, Pin-Lun and Dai, Yun and Kothapalli, Vignesh and Song, Qingquan and Tang, Shao and Zhu, Siyu},
+  url    = {https://github.com/linkedin/Liger-Kernel},
+  year   = {2024}
+}
+```
````
{liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/ops/cross_entropy.py
RENAMED

```diff
@@ -17,7 +17,7 @@ def liger_cross_entropy_kernel(
     BLOCK_SIZE: tl.constexpr,
 ):
     """
-    This kernel computes both cross entropy loss and the gradient of the
+    This kernel computes both cross entropy loss and the gradient of the input.
     We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.
 
     Parameters:
@@ -34,7 +34,7 @@ def liger_cross_entropy_kernel(
     """
 
     # https://github.com/triton-lang/triton/issues/1058
-    #
+    # If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64
    program_id = tl.program_id(0).to(tl.int64)
 
     # 1. Load Y_ptr first because if the target is ignore_index, we can return right away
@@ -90,13 +90,7 @@ def liger_cross_entropy_kernel(
     tl.debug_barrier()
 
     # 5. Calculate the loss
-    # Old Approach: Problematic LogSoftmax
-    # min of bfloat16 and float32 is 1e-38, so we set a value larger than that but small enough
-    # This will overflow if X_y * n_non_ignore is too small. Even if we add a tiny epsilon, it will still overflow
-    # loss = -tl.log(X_y * n_non_ignore)
 
-    # New Approach: Safe LogSoftmax
-    # Therefore, we propose to use safe logsoftmax by reordering the formula.
     # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
     #      = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
     # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
@@ -114,7 +108,7 @@ def liger_cross_entropy_kernel(
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2  #
+MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manually tuning
 
 
 @triton.jit
@@ -184,28 +178,6 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         n_non_ignore = (target != ignore_index).sum().item()
 
         # ensure _input and target are contiguous in the last dimension
-        # there are examples that are NOT contiguous overall but contiguous in the last dimension
-        ####################################################################
-        # tensor = torch.arange(1, 21).reshape(5, -1)
-        # print(tensor)
-        # tensor([[ 1,  2,  3,  4],
-        #         [ 5,  6,  7,  8],
-        #         [ 9, 10, 11, 12],
-        #         [13, 14, 15, 16],
-        #         [17, 18, 19, 20]])
-        # print(tensor.is_contiguous())
-        # True
-        # slice = tensor[::2, :]
-        # print(slice)
-        # tensor([[ 1,  2,  3,  4],
-        #         [ 9, 10, 11, 12],
-        #         [17, 18, 19, 20]])
-        # print(slice.is_contiguous())
-        # False
-        # print(slice.stride())
-        # (8, 1)
-        # slice is NOT a contiguous tensor but is contiguous in the last dimension, CE kernel can execute because the stride is 8, and each triton program will jump by 8
-        ####################################################################
         if _input.stride(-1) != 1:
             _input = _input.contiguous()
         if target.stride(-1) != 1:
@@ -252,10 +224,9 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
         if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
             pass
+
        # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
        # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
-        # Although the Brew trainer should only perform backward once, it encounters this issue.
-        # https://github.com/triton-lang/triton/issues/4004
         else:
             BT, V = _input.shape
             n_rows = BT
```
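The safe log-softmax reordering retained in the comments above can be checked numerically outside Triton; a minimal PyTorch sketch (illustrative, not part of the package):

```python
import torch

# Safe log-softmax: log(softmax(x)[y]) = (x[y] - max(x)) - log(sum(exp(x - max(x))))
# Subtracting max(x) keeps every exponent <= 0, so exp() cannot overflow, and
# sum(exp(x - max(x))) >= 1 because the max term contributes e^0 = 1.
x = torch.randn(8, 32000) * 50  # large logits that would overflow a naive exp()
y = torch.randint(0, 32000, (8,))

m = x.max(dim=-1, keepdim=True).values
safe = (x.gather(-1, y[:, None]) - m).squeeze(-1) - (x - m).exp().sum(-1).log()

ref = torch.log_softmax(x, dim=-1).gather(-1, y[:, None]).squeeze(-1)
assert torch.allclose(safe, ref, atol=1e-4)
```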
{liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/ops/fused_linear_cross_entropy.py
RENAMED

```diff
@@ -1,8 +1,3 @@
-"""Fusing the last linear layer with cross-entropy loss
-
-Reference: https://github.com/mgmalek/efficient_cross_entropy
-"""
-
 import torch
 import triton
 
@@ -11,13 +6,16 @@ from liger_kernel.ops.cross_entropy import element_mul, liger_cross_entropy_kern
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2
+MAX_FUSED_SIZE = 65536 // 2
 
 
 class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, _input, linear, target, ignore_index):
         """
+        Fusing the last linear layer with cross-entropy loss
+        Reference: https://github.com/mgmalek/efficient_cross_entropy
+
         Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding
         the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can
         compute the gradient at the forward pass. By doing so, we don't have to store the _input and target
@@ -54,6 +52,8 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
 
         grad_linear = torch.zeros_like(linear, device=device)
         grad_input = torch.zeros_like(_input, device=device)
+
+        # we use fp32 for loss accumulator
         loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
 
         total_n_non_ignore = (target != ignore_index).sum().item()
```
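The docstring above describes the fusion at a high level; the chunking idea (compute per-chunk logits, loss, and gradients so the full `BT x V` logits tensor never exists at once) can be sketched in plain PyTorch. This is a simplified illustration, not the kernel's actual implementation:

```python
import torch
import torch.nn.functional as F

def chunked_linear_ce(x, weight, target, chunk_size=1024):
    # x: (BT, H), weight: (V, H), target: (BT,)
    # Only a (chunk_size, V) logits slice is alive at any time; the real
    # kernel additionally accumulates grad_input/grad_linear in the same loop.
    total, n = x.new_zeros(()), x.shape[0]
    for i in range(0, n, chunk_size):
        logits = x[i : i + chunk_size] @ weight.T
        total = total + F.cross_entropy(logits, target[i : i + chunk_size], reduction="sum")
    return total / n

x = torch.randn(4096, 128)
w = torch.randn(32000, 128)
t = torch.randint(0, 32000, (4096,))
assert torch.allclose(chunked_linear_ce(x, w, t), F.cross_entropy(x @ w.T, t), atol=1e-3)
```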
{liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/ops/geglu.py
RENAMED

```diff
@@ -1,8 +1,19 @@
+import operator
+
 import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import
+from liger_kernel.ops.utils import (
+    calculate_settings,
+    compare_version,
+    ensure_contiguous,
+)
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    from triton.language.extra.libdevice import tanh
+else:
+    from triton.language.math import tanh
 
 
 @triton.jit
@@ -26,7 +37,7 @@ def _geglu_tanh_forward_kernel(
     sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
     a_cubed = a_row * a_row * a_row
     tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
-    tanh_result =
+    tanh_result = tanh(tanh_arg)
     geglu_a = 0.5 * a_row * (1 + tanh_result)
     c_row = geglu_a * b_row
     tl.store(c + col_offsets, c_row, mask=mask)
@@ -54,7 +65,7 @@ def _geglu_tanh_backward_kernel(
     sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
     a_cubed = a_row * a_row * a_row
     tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
-    tanh_result =
+    tanh_result = tanh(tanh_arg)
     geglu_a = 0.5 * a_row * (1 + tanh_result)
 
     db_row = dc_row * geglu_a
```
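The constants in `_geglu_tanh_forward_kernel` implement the tanh-approximated GELU; this lines up with PyTorch's `approximate="tanh"` variant, which a quick host-side check confirms (illustrative only):

```python
import math
import torch
import torch.nn.functional as F

a = torch.randn(1024)
sqrt_2_over_pi = math.sqrt(2.0 / math.pi)  # 0.7978845608028654, as in the kernel
geglu_a = 0.5 * a * (1 + torch.tanh(sqrt_2_over_pi * (a + 0.044715 * a**3)))
assert torch.allclose(geglu_a, F.gelu(a, approximate="tanh"), atol=1e-6)
```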
{liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/ops/rms_norm.py
RENAMED

```diff
@@ -20,9 +20,12 @@ def _rms_norm_forward(
     BLOCK_SIZE: tl.constexpr,
 ):
     """
+    y_i = (x_i / RMS) * w_i, RMS = sqrt(sum(x_i^2) / N)
+
     Reference:
     1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
     2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+    3. https://arxiv.org/pdf/1910.07467
     """
 
     row_idx = tl.program_id(0)
@@ -36,16 +39,17 @@ def _rms_norm_forward(
     X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
     W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
 
+    mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
+    inv_rms = tl.math.rsqrt(mean_square + eps)
 
-    #
+    # We can save time by caching rms with minimal memory overhead
+    # because rms is much smaller compared to X_row, as rms is for each row.
+    # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
+    tl.store(r_ptr, inv_rms)
 
+    Y_row = X_row * inv_rms * W_row
 
-    tl.store(Y_ptr + col_offsets, output, mask=mask)
+    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
 
 
 @triton.jit
@@ -65,9 +69,10 @@ def _rms_norm_backward(
     BLOCK_SIZE: tl.constexpr,
 ):
     """
-    dx = (1 /
-    dw = sum(dy * (x /
+    dx = (1 / RMS) * [dy * w - (1 / N) * (1 / RMS^2) * ((dy * w) dot x) * x]. * means element-wise multiplication, whileas dot means dot product
+    dw = sum(dy * (x / RMS)). summation over BxT dimension
     """
+
     row_idx = tl.program_id(0)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
@@ -81,34 +86,42 @@ def _rms_norm_backward(
     X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
     W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
 
-    # Get
-    normed = X_row * inv_var
+    # Get cached rms
+    inv_rms_row = tl.load(r_ptr)
 
+    dX_row = (inv_rms_row) * (
+        dY_row * W_row
+        - (1 / n_cols)
+        * inv_rms_row
+        * inv_rms_row
+        * tl.sum(dY_row * W_row * X_row, axis=0)
+        * X_row
+    )
+    tl.store(dY_ptr + col_offsets, dX_row, mask=mask)
 
     # calculate the gradient of W
+    dW_row = dY_row * X_row * inv_rms_row
+    tl.store(dW_ptr + col_offsets, dW_row, mask=mask)
 
 
 class LigerRMSNormFunction(torch.autograd.Function):
     @staticmethod
     @ensure_contiguous
     def forward(ctx, X, W, eps):
+        """
+        X: (B, T, H) or (BxT, H)
+        W: (H,)
+        """
+
         shape = X.shape
         dim = shape[-1]
         X = X.view(-1, dim)
         n_rows, n_cols = X.shape
         BLOCK_SIZE, num_warps = calculate_settings(n_cols)
 
-        Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=
-        r
+        Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+        # r is to cache (1/rms) for each row
+        r = torch.empty(n_rows, dtype=X.dtype, device=X.device)
 
         # Check constraints.
         assert (
@@ -139,6 +152,10 @@ class LigerRMSNormFunction(torch.autograd.Function):
     @staticmethod
     @ensure_contiguous
     def backward(ctx, dY):
+        """
+        Y: (B, T, H) or (BxT, H)
+        """
+
         shape = dY.shape
         dim = shape[-1]
         dY = dY.view(-1, dim)
@@ -146,6 +163,7 @@ class LigerRMSNormFunction(torch.autograd.Function):
         n_rows, n_cols = dY.shape
         dW = torch.zeros_like(X)
 
+        # Here we use dY to store the value of dX to save memory
         _rms_norm_backward[(n_rows,)](
             dY,
             dY.stride(0),
```
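The closed-form gradients in the new docstring can be verified against autograd on the host; a small float64 sketch, independent of the Triton kernel:

```python
import torch

N = 64
x = torch.randn(N, dtype=torch.float64, requires_grad=True)
w = torch.randn(N, dtype=torch.float64, requires_grad=True)
dy = torch.randn(N, dtype=torch.float64)
eps = 1e-6

rms = torch.sqrt((x * x).mean() + eps)  # forward: y = (x / RMS) * w
(x / rms * w).backward(dy)

inv = 1.0 / rms
# dx = (1 / RMS) * [dy * w - (1 / N) * (1 / RMS^2) * ((dy * w) dot x) * x]
dx = inv * (dy * w - (1.0 / N) * inv * inv * (dy * w * x).sum() * x)
# dw = dy * (x / RMS); the kernel then sums this over the BxT dimension
dw = dy * x * inv

assert torch.allclose(dx, x.grad) and torch.allclose(dw, w.grad)
```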
{liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/ops/swiglu.py
RENAMED

```diff
@@ -12,43 +12,43 @@ def silu(x):
 
 @triton.jit
 def _swiglu_forward_kernel(
+    a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
 ):
     program_id = tl.program_id(0)
 
     # locate start index
+    a_ptr += program_id * stride
+    b_ptr += program_id * stride
+    c_ptr += program_id * stride
 
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
     # sigmoid requires type float32
-    a_row = tl.load(
-    b_row = tl.load(
+    a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
+    b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
     c_row = silu(a_row) * b_row
-    tl.store(
+    tl.store(c_ptr + col_offsets, c_row, mask=mask)
 
 
 @triton.jit
 def _swiglu_backward_kernel(
+    dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
 ):
     program_id = tl.program_id(0)
 
     # locate start index
+    dc_ptr += program_id * stride
+    a_ptr += program_id * stride
+    b_ptr += program_id * stride
 
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
-    dc_row = tl.load(
+    dc_row = tl.load(dc_ptr + col_offsets, mask=mask, other=0)
     # sigmoid requires type float32
-    a_row = tl.load(
-    b_row = tl.load(
+    a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
+    b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
 
     # recomputation to save memory
     sig_a = tl.sigmoid(a_row)
@@ -56,8 +56,8 @@ def _swiglu_backward_kernel(
     db_row = dc_row * silu_a
     da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row
 
-    tl.store(
-    tl.store(
+    tl.store(a_ptr + col_offsets, da_row, mask=mask)
+    tl.store(b_ptr + col_offsets, db_row, mask=mask)
 
 
 class LigerSiLUMulFunction(torch.autograd.Function):
```
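`_swiglu_backward_kernel` recomputes `sig_a` and relies on the identity d/da[silu(a)] = silu(a) * (1 - sigmoid(a)) + sigmoid(a); a quick autograd check of that identity (illustrative):

```python
import torch

a = torch.randn(1024, dtype=torch.float64, requires_grad=True)
silu_a = a * torch.sigmoid(a)
silu_a.sum().backward()

sig_a = torch.sigmoid(a.detach())
# the same expression the kernel uses for da_row (with dc_row = 1, b_row = 1)
manual = silu_a.detach() * (1 - sig_a) + sig_a
assert torch.allclose(manual, a.grad)
```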
{liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/ops/utils.py
RENAMED

```diff
@@ -1,7 +1,10 @@
 import functools
+import importlib
+from typing import Callable
 
 import torch
 import triton
+from packaging.version import Version
 
 
 def ensure_contiguous(fn):
@@ -36,3 +39,12 @@ def calculate_settings(n):
     elif BLOCK_SIZE >= 2048:
         num_warps = 8
     return BLOCK_SIZE, num_warps
+
+
+def compare_version(package: str, operator: Callable, target: str):
+    try:
+        pkg = importlib.import_module(package)
+    except ImportError:
+        return False
+    pkg_version = Version(pkg.__version__)
+    return operator(pkg_version, Version(target))
```
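`compare_version` exists to branch on installed library versions at import time; it is used exactly this way in geglu.py earlier in this diff:

```python
import operator
from liger_kernel.ops.utils import compare_version

# True if the installed triton is >= 3.0.0; False if it is older or not installed
if compare_version("triton", operator.ge, "3.0.0"):
    from triton.language.extra.libdevice import tanh
else:
    from triton.language.math import tanh
```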
{liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/transformers/model/llama.py
RENAMED

```diff
@@ -37,6 +37,9 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
+    Copy paste llama forward but replace torch cross entropy with liger fused linear cross entropy
+
+
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
```
{liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/transformers/monkey_patch.py
RENAMED

```diff
@@ -1,27 +1,26 @@
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
+from liger_kernel.transformers.geglu import LigerGEGLUMLP
 from liger_kernel.transformers.model.llama import lce_forward
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
 from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP, LigerSwiGLUMLP
 
 
-# TODO: probably rename utils.py as hf_patcher.py to be more descriptive
 def apply_liger_kernel_to_llama(
     rope: bool = True,
-    cross_entropy: bool =
-    fused_linear_cross_entropy: bool =
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
     swiglu: bool = True,
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
-    to make GPU go burrr.
 
     Args:
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
-        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
         fused_linear_cross_entropy (bool):
-            Whether to apply Liger's fused lienar cross entropy loss. Default is
+            Whether to apply Liger's fused lienar cross entropy loss. Default is True.
             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
@@ -54,7 +53,6 @@ def apply_liger_kernel_to_mistral(
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Mistral models
-    to make GPU go burrr.
 
     Args:
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
@@ -83,12 +81,12 @@ def apply_liger_kernel_to_mixtral(
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Mixtral models
-    to make GPU go burrr.
 
     Args:
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
     """
 
     from transformers.models.mixtral import modeling_mixtral
@@ -101,3 +99,32 @@ def apply_liger_kernel_to_mixtral(
         modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
     if swiglu:
         modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP
+
+
+def apply_liger_kernel_to_gemma(
+    rope: bool = True,
+    cross_entropy: bool = True,
+    rms_norm: bool = True,
+    geglu: bool = True,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Gemma2 models
+    to make GPU go burrr.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
+    """
+    # TODO(yundai424): add convergence test for gemma
+    from transformers.models.gemma import modeling_gemma
+
+    if rope:
+        modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_gemma.GemmaRMSNorm = LigerRMSNorm
+    if cross_entropy:
+        modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
+    if geglu:
+        modeling_gemma.GemmaMLP = LigerGEGLUMLP
```
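The new Gemma entry point is used the same way as the other patching APIs; a sketch mirroring the README's Llama example (the checkpoint name is illustrative). Because the function reassigns classes on `transformers.models.gemma`, patching before instantiating the model ensures newly constructed modules pick up the Liger classes:

```python
import transformers
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma

# swap in Liger RoPE, RMSNorm, GeGLU, and CrossEntropy for Gemma
apply_liger_kernel_to_gemma(rope=True, rms_norm=True, geglu=True)
model = transformers.AutoModelForCausalLM.from_pretrained("google/gemma-2b")
```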
liger_kernel-0.1.0/src/liger_kernel/transformers/trainer_integration.py

```diff
@@ -0,0 +1,45 @@
+import logging
+
+from liger_kernel.transformers.monkey_patch import (
+    apply_liger_kernel_to_gemma,
+    apply_liger_kernel_to_llama,
+    apply_liger_kernel_to_mistral,
+    apply_liger_kernel_to_mixtral,
+)
+
+logger = logging.getLogger(__name__)
+
+# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
+MODEL_TYPE_TO_APPLY_LIGER_FN = {
+    "gemma": apply_liger_kernel_to_gemma,
+    "llama": apply_liger_kernel_to_llama,
+    "mistral": apply_liger_kernel_to_mistral,
+    "mixtral": apply_liger_kernel_to_mixtral,
+}
+
+
+def _apply_liger_kernel(model_type: str = "", **kwargs) -> None:
+    """
+    Applies Liger kernels based on the specified model type. The custom
+    kernels for the specified model type will be applied with the provided
+    keyword arguments, otherwise the default configuration will be used.
+
+    Args:
+        - model_type: the model types as defined in transformers/models/auto/modeling_auto.py
+          and specified in the model's config.json
+        - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
+    """
+
+    if not model_type:
+        logger.info("Model type was not provided. No Liger kernels will be applied.")
+        return
+
+    if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
+        logger.info(
+            f"There are currently no Liger kernels supported for model type: {model_type}."
+        )
+        return
+
+    logger.info(f"Applying Liger kernels for model type: {model_type}.")
+    # Apply the default combination of liger kernels available for the model
+    MODEL_TYPE_TO_APPLY_LIGER_FN[model_type](**kwargs)
```
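`_apply_liger_kernel` dispatches on the `model_type` string found in a model's `config.json`; a hypothetical call site (the checkpoint name is illustrative):

```python
from transformers import AutoConfig
from liger_kernel.transformers.trainer_integration import _apply_liger_kernel

# model_type comes from the model config, e.g. "llama", "mistral", "mixtral", "gemma"
config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B")
_apply_liger_kernel(model_type=config.model_type)

# extra kwargs are forwarded to the matching apply_liger_kernel_to_* function
_apply_liger_kernel(model_type="llama", rms_norm=True, swiglu=True)
```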
{liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel/triton/monkey_patch.py
RENAMED

```diff
@@ -1,12 +1,10 @@
 import os
 import random
 
-from overrides import override
 from triton.runtime.cache import FileCacheManager
 
 
 class LigerTritonFileCacheManager(FileCacheManager):
-    @override
     def put(self, data, filename, binary=True) -> str:
         if not self.cache_dir:
             raise RuntimeError("Could not create or locate cache dir")
```
{liger_kernel-0.0.0 → liger_kernel-0.1.0}/src/liger_kernel.egg-info/SOURCES.txt
RENAMED

```diff
@@ -1,3 +1,6 @@
+LICENSE
+NOTICE
+README.md
 setup.py
 src/liger_kernel.egg-info/PKG-INFO
 src/liger_kernel.egg-info/SOURCES.txt
@@ -20,6 +23,7 @@ src/liger_kernel/transformers/monkey_patch.py
 src/liger_kernel/transformers/rms_norm.py
 src/liger_kernel/transformers/rope.py
 src/liger_kernel/transformers/swiglu.py
+src/liger_kernel/transformers/trainer_integration.py
 src/liger_kernel/transformers/model/__init__.py
 src/liger_kernel/transformers/model/llama.py
 src/liger_kernel/triton/__init__.py
```
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|