liger-kernel 0.0.0__tar.gz → 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel-0.0.1/LICENSE +23 -0
- liger_kernel-0.0.1/NOTICE +4 -0
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/PKG-INFO +3 -1
- liger_kernel-0.0.1/README.md +206 -0
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/setup.py +1 -1
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/ops/cross_entropy.py +4 -33
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/ops/fused_linear_cross_entropy.py +6 -6
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/ops/geglu.py +14 -3
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/ops/rms_norm.py +2 -2
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/ops/utils.py +12 -0
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/transformers/model/llama.py +3 -0
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/transformers/monkey_patch.py +5 -8
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel.egg-info/PKG-INFO +3 -1
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel.egg-info/SOURCES.txt +3 -0
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/setup.cfg +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/transformers/__init__.py +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel.egg-info/dependency_links.txt +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel.egg-info/requires.txt +0 -0
- {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel.egg-info/top_level.txt +0 -0
liger_kernel-0.0.1/LICENSE ADDED
@@ -0,0 +1,23 @@
BSD 2-CLAUSE LICENSE
Copyright 2024 LinkedIn Corporation
All Rights Reserved.
Redistribution and use in source and binary forms, with or
without modification, are permitted provided that the following
conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following
disclaimer in the documentation and/or other materials provided
with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
liger_kernel-0.0.1/README.md ADDED
@@ -0,0 +1,206 @@
# Liger Kernel

[](https://pepy.tech/project/liger-kernel) [](https://badge.fury.io/py/liger-kernel) [](https://badge.fury.io/py/liger-kernel-nightly)

[Installation](#installation) | [Getting Started](#getting-started) | [Structure](#structure) | [APIs](#apis) | [Contributing](#contributing)

**Liger (LinkedIn GPU Efficient Runtime) Kernel** is a collection of Triton kernels designed specifically for LLM training. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. It can effectively increase multi-GPU **training throughput by 20%** and reduce **memory usage by 60%**. The kernels work out of the box with [flash attention](https://github.com/Dao-AILab/flash-attention), PyTorch FSDP, and Microsoft DeepSpeed. We welcome contributions from the community to gather the best kernels for LLM training.

### Basic

| **Example** | **Description** | **Lightning Studio** |
|---|---|---|
| **[Hugging Face Trainer](#liger-kernel)** | Increase throughput by 20% and reduce memory usage by 60% with LLaMA 3 8B on the MMLU dataset using 8 A100s | TBA |
| **[Lightning Trainer](#liger-kernel)** | Increase throughput by 15% and reduce memory usage by 40% with LLaMA 3 8B on the Alpaca dataset using 4 A100s | TBA |

### Advanced

| **Example** | **Description** | **Lightning Studio** |
|---|---|---|
| **[Medusa Multi-head LLM](#liger-kernel)** | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s | TBA |

## Overview

### Supercharge Your Model with Liger Kernel

Gain +20% throughput and reduce memory usage by 60%. Achieve longer context lengths and larger batch sizes. It is also useful if you want to scale your model up to multi-head training or large vocabulary sizes.

| Speed Up | Memory Reduction |
|---|---|
|  |  |

> - Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = bf16, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
> - Hugging Face models start to OOM at a 4K context length, whereas Liger Kernel scales up to 16K.
> - **Fused Linear Cross Entropy Loss** is enabled to significantly reduce memory usage.

### Patch HF model with one line or use individual kernels

| Patch Existing HF Model | Compose Your Own Model |
|---|---|
|  |  |

### Key Features

- **Ease of use:** Simply patch your Hugging Face model with one line of code, or compose your own model using our kernels.
- **Time- and memory-efficient:** In the same spirit as Flash-Attn, but for layers like **RMSNorm**, **RoPE**, and **CrossEntropy**! Increases multi-GPU training throughput by 20% and reduces memory usage by 60% via **kernel fusion**, **in-place replacement**, and **chunking** techniques.
- **Exact:** Exact kernels, no approximations. Both forward and backward passes are implemented with rigorous unit and convergence testing to ensure accuracy.
- **Lightweight:** The kernels have minimal dependencies, requiring only Torch and Triton; no extra libraries needed! Say goodbye to dependency headaches!
- **Multi-GPU supported:** Compatible with multi-GPU setups (PyTorch FSDP and DeepSpeed).

### Target Audiences

- **Researchers**: Looking to compose models using efficient and reliable kernels for frontier experiments.
- **ML Practitioners**: Focused on maximizing GPU training efficiency with optimal, high-performance kernels.
- **Curious Novices**: Eager to learn how to write reliable Triton kernels to enhance training efficiency.

## Installation

### Dependencies

- `torch >= 2.1.2`
- `triton >= 2.3.0`
- `transformers >= 4.40.1`

To install the stable version:

```bash
$ pip install liger-kernel
```

To install the nightly version:

```bash
$ pip install liger-kernel-nightly
```

## Getting Started

### 1. Patch Existing Hugging Face Models

Using the [patching APIs](#patching), you can swap the default Hugging Face implementation for optimized Liger kernels.

```python
import transformers

from liger_kernel.transformers import apply_liger_kernel_to_llama

model = transformers.AutoModelForCausalLM.from_pretrained("<some llama model>")

# Adding this line automatically monkey-patches the model with the optimized kernels
apply_liger_kernel_to_llama()
```

### 2. Compose Your Own Model

You can use individual [kernels](#kernels) to compose your own models.

```python
import torch
import torch.nn as nn

from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

model = nn.Linear(128, 256).to("cuda")

# LigerFusedLinearCrossEntropyLoss fuses the linear and cross entropy layers
# together and performs chunk-by-chunk computation to reduce memory
loss_fn = LigerFusedLinearCrossEntropyLoss()

input = torch.randn(4, 128, requires_grad=True, device="cuda")
target = torch.empty(4, dtype=torch.long, device="cuda").random_(256)

loss = loss_fn(model.weight, input, target)
loss.backward()
```

## Structure

### Source Code

- `ops/`: Core Triton operations.
- `transformers/`: PyTorch `nn.Module` implementations built on Triton operations, compliant with the `transformers` API.

### Tests

- `transformers/`: Correctness tests for the Triton-based layers.
- `convergence/`: Patches Hugging Face models with all kernels, runs multiple iterations, and compares weights, logits, and loss layer by layer.

### Benchmark

- `benchmark/`: Execution time and memory benchmarks compared to Hugging Face layers.

## APIs

### Patching

| **Model** | **API** | **Supported Operations** |
|---|---|---|
| LLaMA (2 & 3) | `liger_kernel.transformers.apply_liger_kernel_to_llama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Mistral | `liger_kernel.transformers.apply_liger_kernel_to_mistral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
| Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |

### Kernels

| **Kernel** | **API** | **Description** |
|---|---|---|
| RMSNorm | `liger_kernel.transformers.LigerRMSNorm` | [RMSNorm Paper](https://arxiv.org/pdf/1910.07467) |
| RoPE | `liger_kernel.transformers.liger_rotary_pos_emb` | [RoPE Paper](https://arxiv.org/pdf/2104.09864) |
| SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` | [SwiGLU Paper](https://arxiv.org/pdf/2002.05202) |
| CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` | [PyTorch CrossEntropyLoss Documentation](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) |
| FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss` | Inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy), with additional optimizations |

## Note on ML Compiler

### 1. Torch Compile

Since Liger Kernel is 100% Triton-based, it works seamlessly with Torch Compile. In the following benchmark, Liger Kernel further optimizes the model on top of Torch Compile, reducing memory by more than half.

| Configuration | Throughput (tokens/sec) | Memory Reserved (MB) |
|---|---|---|
| Torch Compile | 3780 | 66358 |
| Torch Compile + Liger Kernel | 3702 | 31000 |

> **Note:**
> 1. **Fused Linear Cross Entropy Loss** is enabled.
> 2. Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Seq Len = 4096, Data Type = bf16, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
> 3. Tested on torch `2.5.0.dev20240731+cu118`.
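Since patching and `torch.compile` are orthogonal, they can be combined directly. A minimal sketch (the checkpoint placeholder follows the patching example above; this snippet is illustrative, not from this README):

```python
import torch
import transformers

from liger_kernel.transformers import apply_liger_kernel_to_llama

# Patch first so that compile captures the Liger modules
apply_liger_kernel_to_llama()

model = transformers.AutoModelForCausalLM.from_pretrained("<some llama model>")
model = torch.compile(model)
```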
### 2. Lightning Thunder

*WIP*

## Contributing

[CONTRIBUTING GUIDE](https://github.com/linkedin/Liger-Kernel/blob/main/CONTRIBUTING.md)

## Acknowledgement

- [flash-attn](https://github.com/Dao-AILab/flash-attention) and [Unsloth](https://github.com/unslothai/unsloth) for inspiration on Triton kernels for training
- Andrej Karpathy's [tiny shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt), used for convergence testing

## License

[BSD 2-CLAUSE](https://github.com/linkedin/Liger-Kernel/blob/main/LICENSE)

## Cite this work

BibLaTeX entry:

```bib
@software{liger2024,
  title  = {Liger-Kernel: Efficient Triton Kernels for LLM Training},
  author = {Hsu, Pin-Lun and Dai, Yun and Kothapalli, Vignesh and Song, Qingquan and Tang, Shao and Zhu, Siyu},
  url    = {https://github.com/linkedin/Liger-Kernel},
  year   = {2024}
}
```
{liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/ops/cross_entropy.py RENAMED
@@ -17,7 +17,7 @@ def liger_cross_entropy_kernel(
     BLOCK_SIZE: tl.constexpr,
 ):
     """
-    This kernel computes both cross entropy loss and the gradient of the
+    This kernel computes both cross entropy loss and the gradient of the input.
     We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.

     Parameters:
@@ -34,7 +34,7 @@ def liger_cross_entropy_kernel(
     """

     # https://github.com/triton-lang/triton/issues/1058
-    #
+    # If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64
     program_id = tl.program_id(0).to(tl.int64)

     # 1. Load Y_ptr first because if the target is ignore_index, we can return right away
@@ -90,13 +90,7 @@ def liger_cross_entropy_kernel(
     tl.debug_barrier()

     # 5. Calculate the loss
-    # Old Approach: Problematic LogSoftmax
-    # min of bfloat16 and float32 is 1e-38, so we set a value larger than that but small enough
-    # This will overflow if X_y * n_non_ignore is too small. Even if we add a tiny epsilon, it will still overflow
-    # loss = -tl.log(X_y * n_non_ignore)

-    # New Approach: Safe LogSoftmax
-    # Therefore, we propose to use safe logsoftmax by reordering the formula.
     # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
     #      = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
     # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
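The retained comments describe the standard max-subtraction reordering. A small PyTorch sketch (illustrative only, not part of the package) shows why the kept formula is safe where the deleted one was not:

```python
import torch

x = torch.tensor([10.0, 200.0, -5.0])  # logits; exp(200.) overflows float32
y = 1                                   # index of the true class

# Naive form: exp overflows to inf, and inf/inf gives nan
naive = torch.log(torch.exp(x)[y] / torch.exp(x).sum())

# Safe form: subtract max(x) first; the sum is >= 1, so log never underflows
m = x.max()
safe = (x[y] - m) - torch.log(torch.exp(x - m).sum())

print(naive)  # nan
print(safe)   # 0., matches torch.log_softmax(x, 0)[y]
```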
@@ -114,7 +108,7 @@ def liger_cross_entropy_kernel(
     # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
     # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
     # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-    MAX_FUSED_SIZE = 65536 // 2  #
+    MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manually tuning


 @triton.jit
@@ -184,28 +178,6 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
     n_non_ignore = (target != ignore_index).sum().item()

     # ensure _input and target are contiguous in the last dimension
-    # there are examples that are NOT contiguous overall but contiguous in the last dimension
-    ####################################################################
-    # tensor = torch.arange(1, 21).reshape(5, -1)
-    # print(tensor)
-    # tensor([[ 1,  2,  3,  4],
-    #         [ 5,  6,  7,  8],
-    #         [ 9, 10, 11, 12],
-    #         [13, 14, 15, 16],
-    #         [17, 18, 19, 20]])
-    # print(tensor.is_contiguous())
-    # True
-    # slice = tensor[::2, :]
-    # print(slice)
-    # tensor([[ 1,  2,  3,  4],
-    #         [ 9, 10, 11, 12],
-    #         [17, 18, 19, 20]])
-    # print(slice.is_contiguous())
-    # False
-    # print(slice.stride())
-    # (8, 1)
-    # slice is NOT a contiguous tensor but is contiguous in the last dimension, CE kernel can execute because the stride is 8, and each triton program will jump by 8
-    ####################################################################
     if _input.stride(-1) != 1:
         _input = _input.contiguous()
     if target.stride(-1) != 1:
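The deleted comment block demonstrated why only last-dimension contiguity matters; the same demonstration as a runnable sketch (illustrative, not package code):

```python
import torch

t = torch.arange(1, 21).reshape(5, 4)
s = t[::2, :]             # keep every other row

print(s.is_contiguous())  # False: the view skips rows
print(s.stride())         # (8, 1): still contiguous within each row

# stride(-1) == 1 is all the CE kernel needs: each Triton program simply
# advances by the row stride (8 here), so no defensive copy is required.
```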
@@ -252,10 +224,9 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
     if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
         pass
+
     # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
     # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
-    # Although the Brew trainer should only perform backward once, it encounters this issue.
-    # https://github.com/triton-lang/triton/issues/4004
     else:
         BT, V = _input.shape
         n_rows = BT
{liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/ops/fused_linear_cross_entropy.py RENAMED
@@ -1,8 +1,3 @@
-"""Fusing the last linear layer with cross-entropy loss
-
-Reference: https://github.com/mgmalek/efficient_cross_entropy
-"""
-
 import torch
 import triton

@@ -11,13 +6,16 @@ from liger_kernel.ops.cross_entropy import element_mul, liger_cross_entropy_kern
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2
+MAX_FUSED_SIZE = 65536 // 2


 class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, _input, linear, target, ignore_index):
         """
+        Fusing the last linear layer with cross-entropy loss
+        Reference: https://github.com/mgmalek/efficient_cross_entropy
+
         Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding
         the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can
         compute the gradient at the forward pass. By doing so, we don't have to store the _input and target
@@ -54,6 +52,8 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):

     grad_linear = torch.zeros_like(linear, device=device)
     grad_input = torch.zeros_like(_input, device=device)
+
+    # we use fp32 for loss accumulator
     loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)

     total_n_non_ignore = (target != ignore_index).sum().item()
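For intuition about the chunking that the relocated docstring describes, here is a rough pure-PyTorch sketch (the function name and chunk size are illustrative assumptions; the actual kernel additionally computes gradients during the forward pass so each chunk's logits can be freed immediately):

```python
import torch
import torch.nn.functional as F

def chunked_linear_ce(hidden, weight, target, chunk_size=1024):
    """Mean CE over BT rows without materializing the full (BT, V) logits."""
    total, n = hidden.new_zeros(()), target.numel()
    for h, t in zip(hidden.split(chunk_size), target.split(chunk_size)):
        logits = h @ weight.T  # only a (chunk_size, V) slice is alive at once
        total = total + F.cross_entropy(logits, t, reduction="sum")
    return total / n
```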
{liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/ops/geglu.py RENAMED
@@ -1,8 +1,19 @@
+import operator
+
 import torch
 import triton
 import triton.language as tl

-from liger_kernel.ops.utils import
+from liger_kernel.ops.utils import (
+    calculate_settings,
+    compare_version,
+    ensure_contiguous,
+)
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    from triton.language.extra.libdevice import tanh
+else:
+    from triton.language.math import tanh


 @triton.jit
@@ -26,7 +37,7 @@ def _geglu_tanh_forward_kernel(
     sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
     a_cubed = a_row * a_row * a_row
     tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
-    tanh_result =
+    tanh_result = tanh(tanh_arg)
     geglu_a = 0.5 * a_row * (1 + tanh_result)
     c_row = geglu_a * b_row
     tl.store(c + col_offsets, c_row, mask=mask)
@@ -54,7 +65,7 @@ def _geglu_tanh_backward_kernel(
     sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
     a_cubed = a_row * a_row * a_row
     tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
-    tanh_result =
+    tanh_result = tanh(tanh_arg)
     geglu_a = 0.5 * a_row * (1 + tanh_result)

     db_row = dc_row * geglu_a
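For reference, both kernels implement the tanh-approximated GELU gated by a second projection; as a plain PyTorch sketch (illustrative, not the kernel itself):

```python
import torch

def geglu_tanh_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
    tanh_arg = sqrt_2_over_pi * (a + 0.044715 * a.pow(3))
    return 0.5 * a * (1 + torch.tanh(tanh_arg)) * b  # gelu(a, tanh approx) * b

# Equivalent to torch.nn.functional.gelu(a, approximate="tanh") * b
```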
{liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/ops/rms_norm.py RENAMED
@@ -107,8 +107,8 @@ class LigerRMSNormFunction(torch.autograd.Function):
     n_rows, n_cols = X.shape
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)

-    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=
-    r = torch.empty(n_rows, dtype=X.dtype, device=
+    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+    r = torch.empty(n_rows, dtype=X.dtype, device=X.device)

     # Check constraints.
     assert (
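For orientation, `Y` is the normalized output and `r` presumably caches a per-row statistic for the backward pass; a sketch of the usual RMSNorm computation ([RMSNorm paper](https://arxiv.org/pdf/1910.07467)), with the eps placement assumed:

```python
import torch

def rms_norm_reference(X: torch.Tensor, W: torch.Tensor, eps: float = 1e-6):
    # r is the per-row inverse RMS; caching it spares backward a recompute
    r = torch.rsqrt(X.pow(2).mean(dim=-1, keepdim=True) + eps)
    Y = X * r * W
    return Y, r.squeeze(-1)
```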
{liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/ops/utils.py RENAMED
@@ -1,7 +1,10 @@
 import functools
+import importlib
+from typing import Callable

 import torch
 import triton
+from packaging.version import Version


 def ensure_contiguous(fn):
@@ -36,3 +39,12 @@ def calculate_settings(n):
     elif BLOCK_SIZE >= 2048:
         num_warps = 8
     return BLOCK_SIZE, num_warps
+
+
+def compare_version(package: str, operator: Callable, target: str):
+    try:
+        pkg = importlib.import_module(package)
+    except ImportError:
+        return False
+    pkg_version = Version(pkg.__version__)
+    return operator(pkg_version, Version(target))
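geglu.py above shows the intended use (gating a version-dependent import); two more illustrative calls, the second with a deliberately nonexistent package:

```python
import operator

from liger_kernel.ops.utils import compare_version

# True whenever the installed triton satisfies the declared floor of 2.3.0
print(compare_version("triton", operator.ge, "2.3.0"))

# A missing package returns False instead of raising
print(compare_version("not_a_real_package", operator.ge, "1.0"))  # False
```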
{liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/transformers/model/llama.py RENAMED
@@ -37,6 +37,9 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
+    Copy paste llama forward but replace torch cross entropy with liger fused linear cross entropy
+
+
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
{liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/transformers/monkey_patch.py RENAMED
@@ -5,23 +5,21 @@ from liger_kernel.transformers.rope import liger_rotary_pos_emb
 from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP, LigerSwiGLUMLP


-# TODO: probably rename utils.py as hf_patcher.py to be more descriptive
 def apply_liger_kernel_to_llama(
     rope: bool = True,
-    cross_entropy: bool =
-    fused_linear_cross_entropy: bool =
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
     swiglu: bool = True,
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
-    to make GPU go burrr.

     Args:
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
-        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
         fused_linear_cross_entropy (bool):
-            Whether to apply Liger's fused lienar cross entropy loss. Default is
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
             If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
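Spelled out, the new defaults are equivalent to the following call (only the flag values come from this diff):

```python
from liger_kernel.transformers import apply_liger_kernel_to_llama

# Same as apply_liger_kernel_to_llama() as of 0.0.1; note that cross_entropy
# and fused_linear_cross_entropy cannot both be True.
apply_liger_kernel_to_llama(
    rope=True,
    cross_entropy=False,
    fused_linear_cross_entropy=True,
    rms_norm=True,
    swiglu=True,
)
```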
@@ -54,7 +52,6 @@ def apply_liger_kernel_to_mistral(
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Mistral models
-    to make GPU go burrr.

     Args:
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
@@ -83,12 +80,12 @@ def apply_liger_kernel_to_mixtral(
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Mixtral models
-    to make GPU go burrr.

     Args:
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
     """

     from transformers.models.mixtral import modeling_mixtral
All other files listed above with `+0 -0` are unchanged between the two versions.