liger-kernel 0.0.0__tar.gz → 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. liger_kernel-0.0.1/LICENSE +23 -0
  2. liger_kernel-0.0.1/NOTICE +4 -0
  3. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/PKG-INFO +3 -1
  4. liger_kernel-0.0.1/README.md +206 -0
  5. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/setup.py +1 -1
  6. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/ops/cross_entropy.py +4 -33
  7. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/ops/fused_linear_cross_entropy.py +6 -6
  8. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/ops/geglu.py +14 -3
  9. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/ops/rms_norm.py +2 -2
  10. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/ops/utils.py +12 -0
  11. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/transformers/model/llama.py +3 -0
  12. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/transformers/monkey_patch.py +5 -8
  13. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel.egg-info/PKG-INFO +3 -1
  14. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel.egg-info/SOURCES.txt +3 -0
  15. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/setup.cfg +0 -0
  16. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/ops/__init__.py +0 -0
  17. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/ops/rope.py +0 -0
  18. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/ops/swiglu.py +0 -0
  19. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/transformers/__init__.py +0 -0
  20. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  21. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  22. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/transformers/geglu.py +0 -0
  23. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/transformers/model/__init__.py +0 -0
  24. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/transformers/rms_norm.py +0 -0
  25. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/transformers/rope.py +0 -0
  26. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/transformers/swiglu.py +0 -0
  27. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/triton/__init__.py +0 -0
  28. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel/triton/monkey_patch.py +0 -0
  29. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel.egg-info/dependency_links.txt +0 -0
  30. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel.egg-info/requires.txt +0 -0
  31. {liger_kernel-0.0.0 → liger_kernel-0.0.1}/src/liger_kernel.egg-info/top_level.txt +0 -0
@@ -0,0 +1,23 @@
1
+ BSD 2-CLAUSE LICENSE
2
+ Copyright 2024 LinkedIn Corporation
3
+ All Rights Reserved.
4
+ Redistribution and use in source and binary forms, with or
5
+ without modification, are permitted provided that the following
6
+ conditions are met:
7
+ 1. Redistributions of source code must retain the above copyright
8
+ notice, this list of conditions and the following disclaimer.
9
+ 2. Redistributions in binary form must reproduce the above
10
+ copyright notice, this list of conditions and the following
11
+ disclaimer in the documentation and/or other materials provided
12
+ with the distribution.
13
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
14
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
15
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
16
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
17
+ HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
18
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
19
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -0,0 +1,4 @@
1
+ Copyright 2024 LinkedIn Corporation
2
+ All Rights Reserved.
3
+
4
+ Licensed under the BSD 2-Clause License (the "License"). See License in the project root for license information.
@@ -1,4 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel
3
- Version: 0.0.0
3
+ Version: 0.0.1
4
4
  Provides-Extra: dev
5
+ License-File: LICENSE
6
+ License-File: NOTICE
@@ -0,0 +1,206 @@
1
+ # Liger Kernel
2
+
3
+ [![Downloads](https://static.pepy.tech/badge/liger-kernel)](https://pepy.tech/project/liger-kernel) [![PyPI version](https://badge.fury.io/py/liger-kernel.svg)](https://badge.fury.io/py/liger-kernel) [![PyPI nightly version](https://badge.fury.io/py/liger-kernel-nightly.svg)](https://badge.fury.io/py/liger-kernel-nightly)
4
+
5
+
6
+ [Installation](#installation) | [Getting Started](#getting-started) | [Structure](#structure) | [APIs](#apis) | [Contributing](#contributing)
7
+
8
+ **Liger (LinkedIn GPU Efficient Runtime) Kernel** is a collection of Triton kernels designed specifically for LLM training. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. It can effectively increase multi-GPU **training throughput by 20%** and reduce **memory usage by 60%**. The kernel works out of the box with [flash attention](https://github.com/Dao-AILab/flash-attention), PyTorch FSDP, and Microsoft DeepSpeed. We welcome contributions from the community to gather the best kernels for LLM training.
9
+
10
+
11
+ ### Basic
12
+
13
+ | **Example** | **Description** | **Lightning Studio** |
14
+ |------------------------------------------------|---------------------------------------------------------------------------------------------------|----------------------|
15
+ | **[Hugging Face Trainer](#liger-kernel)** | Increase throughput by 20% and reduce memory usage by 60% with LLaMA 3 8B on the MMLU dataset using 8 A100s | TBA |
16
+ | **[Lightning Trainer](#liger-kernel)** | Increase throughput by 15% and reduce memory usage by 40% with LLaMA 3 8B on the Alpaca dataset using 4 A100s | TBA |
17
+
18
+ ### Advanced
19
+
20
+ | **Example** | **Description** | **Lightning Studio** |
21
+ |------------------------------------------------|---------------------------------------------------------------------------------------------------|----------------------|
22
+ | **[Medusa Multi-head LLM](#liger-kernel)** | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s | TBA |
23
+
24
+ ## Overview
25
+
26
+ ### Supercharge Your Model with Liger Kernel
27
+
28
+ Gain +20% throughput and reduce memory usage by 60%. Achieve longer context lengths and larger batch sizes. It’s also useful if you want to scale up your model to multi-head training or large vocabulary sizes.
29
+
30
+ | Speed Up | Memory Reduction |
31
+ |--------------------------|-------------------------|
32
+ | ![Speed up](docs/images/e2e-tps.png) | ![Memory](docs/images/e2e-memory.png) |
33
+
34
+
35
+ > - Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = bf16, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
36
+ > - HuggingFace models start to OOM at a 4K context length, whereas Liger Kernel scales up to 16K.
37
+ > - **Fused Linear Cross Entropy Loss** is enabled to significantly reduce memory usage.
38
+
39
+ ### Patch HF model with one line or use individual kernels
40
+
41
+ | Patch Existing HF Model | Compose Your Own Model |
42
+ |--------------------------|-------------------------|
43
+ | ![Patch](docs/images/patch.gif) | ![Compose](docs/images/compose.gif) |
44
+
45
+ ### Key Features
46
+
47
+ - **Ease of use:** Simply patch your Hugging Face model with one line of code, or compose your own model using our kernels.
48
+ - **Time- and memory-efficient:** In the same spirit as Flash-Attn, but for layers like **RMSNorm**, **RoPE**, **CrossEntropy**! Increases multi-GPU training throughput by 20% and reduces memory usage by 60% with **kernel fusion**, **in-place replacement**, and **chunking** techniques.
49
+ - **Exact:** Exact kernels—no approximations. Both forward and backward are implemented with rigorous unit and convergence testing to ensure accuracy.
50
+ - **Lightweight:** The kernels have minimal dependencies, requiring only Torch and Triton—no extra libraries needed! Say goodbye to dependency headaches!
51
+ - **Multi-GPU supported:** Compatible with multi-GPU setups (PyTorch FSDP and DeepSpeed).
52
+
53
+ ### Target Audiences
54
+
55
+ - **Researchers**: Looking to compose models using efficient and reliable kernels for frontier experiments.
56
+ - **ML Practitioners**: Focused on maximizing GPU training efficiency with optimal, high-performance kernels.
57
+ - **Curious Novices**: Eager to learn how to write reliable Triton kernels to enhance training efficiency.
58
+
59
+
60
+ ## Installation
61
+
62
+ ### Dependencies
63
+
64
+ - `torch >= 2.1.2`
65
+ - `triton >= 2.3.0`
66
+ - `transformers >= 4.40.1`
67
+
68
+ To install the stable version:
69
+
70
+ ```bash
71
+ $ pip install liger-kernel
72
+ ```
73
+
74
+ To install the nightly version:
75
+
76
+ ```bash
77
+ $ pip install liger-kernel-nightly
78
+ ```
79
+
80
+ ## Getting Started
81
+
82
+ ### 1. Patch Existing Hugging Face Models
83
+
84
+ Using the [patching APIs](#patching), you can swap the layers of a Hugging Face model for optimized Liger Kernels.
85
+
86
+ ```python
87
+ from liger_kernel.transformers import apply_liger_kernel_to_llama
88
+ import transformers
89
+
90
+ # Adding this line monkey-patches Hugging Face's Llama modules with the optimized kernels.
91
+ # Call it before loading the model so that the newly created modules use the Liger kernels.
92
+ apply_liger_kernel_to_llama()
93
+
94
+ model = transformers.AutoModelForCausalLM.from_pretrained("<some llama model>")
95
+ ```
96
+
97
+
98
+
99
+
100
+
101
+ ### 2. Compose Your Own Model
102
+
103
+ You can take individual [kernels](#kernels) to compose your models.
104
+
105
+ ```python
106
+ from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss
107
+ import torch.nn as nn
108
+ import torch
109
+
110
+ model = nn.Linear(128, 256).to("cuda")
111
+
112
+ # LigerFusedLinearCrossEntropyLoss fuses linear and cross entropy layers together and performs chunk-by-chunk computation to reduce memory
113
+ loss_fn = LigerFusedLinearCrossEntropyLoss()
114
+
115
+ input = torch.randn(4, 128, requires_grad=True, device="cuda")
116
+ target = torch.empty(4, dtype=torch.long, device="cuda").random_(256)
117
+
118
+ loss = loss_fn(model.weight, input, target)
119
+ loss.backward()
120
+ ```
121
+
122
+
123
+ ## Structure
124
+
125
+ ### Source Code
126
+
127
+ - `ops/`: Core Triton operations.
128
+ - `transformers/`: PyTorch `nn.Module` implementations built on Triton operations, compliant with the `transformers` API.
129
+
130
+ ### Tests
131
+
132
+ - `transformers/`: Correctness tests for the Triton-based layers.
133
+ - `convergence/`: Patches Hugging Face models with all kernels, runs multiple iterations, and compares weights, logits, and loss layer by layer.
134
+
135
+ ### Benchmark
136
+
137
+ - `benchmark/`: Execution time and memory benchmarks compared to Hugging Face layers.
138
+
139
+ ## APIs
140
+
141
+ ### Patching
142
+
143
+ | **Model** | **API** | **Supported Operations** |
144
+ |-------------|--------------------------------------------------------------|-------------------------------------------------------------------------|
145
+ | LLaMA (2 & 3) | `liger_kernel.transformers.apply_liger_kernel_to_llama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
146
+ | Mistral | `liger_kernel.transformers.apply_liger_kernel_to_mistral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
147
+ | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
148
+
149
+
150
+ ### Kernels
151
+
152
+ | **Kernel** | **API** | **Description** |
153
+ |---------------------------|-------------------------------------------------------------|-----------------|
154
+ | RMSNorm | `liger_kernel.transformers.LigerRMSNorm` | [RMSNorm Paper](https://arxiv.org/pdf/1910.07467) |
155
+ | RoPE | `liger_kernel.transformers.liger_rotary_pos_emb` | [RoPE Paper](https://arxiv.org/pdf/2104.09864) |
156
+ | SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` | [SwiGLU Paper](https://arxiv.org/pdf/2002.05202) |
157
+ | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` | [PyTorch CrossEntropyLoss Documentation](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) |
158
+ | FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`| Inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy), with additional optimizations |
159
+
160
+
161
+
162
+ ## Note on ML Compiler
163
+
164
+ ### 1. Torch Compile
165
+
166
+ Since Liger Kernel is 100% Triton-based, it works seamlessly with Torch Compile. In the following example, Liger Kernel can further optimize the model on top of Torch Compile, reducing the memory by more than half.
167
+
168
+ | Configuration | Throughput (tokens/sec) | Memory Reserved (MB) |
169
+ |--------------------------------|----------------------------|-------------------------|
170
+ | Torch Compile | 3780 | 66358 |
171
+ | Torch Compile + Liger Kernel | 3702 | 31000 |
172
+
173
+ > **Note:**
174
+ > 1. **Fused Linear Cross Entropy Loss** is enabled.
175
+ > 2. Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Seq Len = 4096, Data Type = bf16, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
176
+ > 3. Tested on torch `2.5.0.dev20240731+cu118`
177
+
178
+ ### 2. Lightning Thunder
179
+
180
+ *WIP*
181
+
182
+ ## Contributing
183
+
184
+ [CONTRIBUTING GUIDE](https://github.com/linkedin/Liger-Kernel/blob/main/CONTRIBUTING.md)
185
+
186
+ ## Acknowledgement
187
+
188
+ - [flash-attn](https://github.com/Dao-AILab/flash-attention) and [Unsloth](https://github.com/unslothai/unsloth) for inspiration in Triton kernels for training
189
+ - [Tiny Shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) by Andrej Karpathy, for convergence testing
190
+
191
+
192
+ ## License
193
+
194
+ [BSD 2-CLAUSE](https://github.com/linkedin/Liger-Kernel/blob/main/LICENSE)
195
+
196
+ ## Cite this work
197
+
198
+ Biblatex entry:
199
+ ```bib
200
+ @software{liger2024,
201
+ title = {Liger-Kernel: Efficient Triton Kernels for LLM Training},
202
+ author = {Hsu, Pin-Lun and Dai, Yun and Kothapalli, Vignesh and Song, Qingquan and Tang, Shao and Zhu, Siyu},
203
+ url = {https://github.com/linkedin/Liger-Kernel},
204
+ year = {2024}
205
+ }
206
+ ```
@@ -1,6 +1,6 @@
1
1
  from setuptools import find_namespace_packages, setup
2
2
 
3
- __version__ = "0.0.0"
3
+ __version__ = "0.0.1"
4
4
 
5
5
  setup(
6
6
  name="liger_kernel",
@@ -17,7 +17,7 @@ def liger_cross_entropy_kernel(
17
17
  BLOCK_SIZE: tl.constexpr,
18
18
  ):
19
19
  """
20
- This kernel computes both cross entropy loss and the gradient of the _input.
20
+ This kernel computes both cross entropy loss and the gradient of the input.
21
21
  We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.
22
22
 
23
23
  Parameters:
@@ -34,7 +34,7 @@ def liger_cross_entropy_kernel(
34
34
  """
35
35
 
36
36
  # https://github.com/triton-lang/triton/issues/1058
37
- # Essentially if B*T*V is too large, program_id * stride will overflow out of int32
37
+ # If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64
38
38
  program_id = tl.program_id(0).to(tl.int64)
39
39
 
40
40
  # 1. Load Y_ptr first because if the target is ignore_index, we can return right away
@@ -90,13 +90,7 @@ def liger_cross_entropy_kernel(
90
90
  tl.debug_barrier()
91
91
 
92
92
  # 5. Calculate the loss
93
- # Old Approach: Problematic LogSoftmax
94
- # min of bfloat16 and float32 is 1e-38, so we set a value larger than that but small enough
95
- # This will overflow if X_y * n_non_ignore is too small. Even if we add a tiny epsilon, it will still overflow
96
- # loss = -tl.log(X_y * n_non_ignore)
97
93
 
98
- # New Approach: Safe LogSoftmax
99
- # Therefore, we propose to use safe logsoftmax by reordering the formula.
100
94
  # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
101
95
  # = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
102
96
  # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
@@ -114,7 +108,7 @@ def liger_cross_entropy_kernel(
114
108
  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
115
109
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
116
110
  # The optimal maximum block size depends on your hardware, your kernel, and your dtype
117
- MAX_FUSED_SIZE = 65536 // 2 # manual tune a bit
111
+ MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning
118
112
 
119
113
 
120
114
  @triton.jit
@@ -184,28 +178,6 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
184
178
  n_non_ignore = (target != ignore_index).sum().item()
185
179
 
186
180
  # ensure _input and target are contiguous in the last dimension
187
- # there are examples that are NOT contiguous overall but contiguous in the last dimension
188
- ####################################################################
189
- # tensor = torch.arange(1, 21).reshape(5, -1)
190
- # print(tensor)
191
- # tensor([[ 1, 2, 3, 4],
192
- # [ 5, 6, 7, 8],
193
- # [ 9, 10, 11, 12],
194
- # [13, 14, 15, 16],
195
- # [17, 18, 19, 20]])
196
- # print(tensor.is_contiguous())
197
- # True
198
- # slice = tensor[::2, :]
199
- # print(slice)
200
- # tensor([[ 1, 2, 3, 4],
201
- # [ 9, 10, 11, 12],
202
- # [17, 18, 19, 20]])
203
- # print(slice.is_contiguous())
204
- # False
205
- # print(slice.stride())
206
- # (8, 1)
207
- # slice is NOT a contiguous tensor but is contiguous in the last dimension, CE kernel can execute because the stride is 8, and each triton program will jump by 8
208
- ####################################################################
209
181
  if _input.stride(-1) != 1:
210
182
  _input = _input.contiguous()
211
183
  if target.stride(-1) != 1:
@@ -252,10 +224,9 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
252
224
  # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
253
225
  if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
254
226
  pass
227
+
255
228
  # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
256
229
  # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
257
- # Although the Brew trainer should only perform backward once, it encounters this issue.
258
- # https://github.com/triton-lang/triton/issues/4004
259
230
  else:
260
231
  BT, V = _input.shape
261
232
  n_rows = BT
@@ -1,8 +1,3 @@
1
- """Fusing the last linear layer with cross-entropy loss
2
-
3
- Reference: https://github.com/mgmalek/efficient_cross_entropy
4
- """
5
-
6
1
  import torch
7
2
  import triton
8
3
 
@@ -11,13 +6,16 @@ from liger_kernel.ops.cross_entropy import element_mul, liger_cross_entropy_kern
11
6
  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
12
7
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
13
8
  # The optimal maximum block size depends on your hardware, your kernel, and your dtype
14
- MAX_FUSED_SIZE = 65536 // 2 # manual tune a bit
9
+ MAX_FUSED_SIZE = 65536 // 2
15
10
 
16
11
 
17
12
  class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
18
13
  @staticmethod
19
14
  def forward(ctx, _input, linear, target, ignore_index):
20
15
  """
16
+ Fusing the last linear layer with cross-entropy loss
17
+ Reference: https://github.com/mgmalek/efficient_cross_entropy
18
+
21
19
  Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding
22
20
  the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can
23
21
  compute the gradient at the forward pass. By doing so, we don't have to store the _input and target
@@ -54,6 +52,8 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
54
52
 
55
53
  grad_linear = torch.zeros_like(linear, device=device)
56
54
  grad_input = torch.zeros_like(_input, device=device)
55
+
56
+ # we use fp32 for loss accumulator
57
57
  loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
58
58
 
59
59
  total_n_non_ignore = (target != ignore_index).sum().item()
@@ -1,8 +1,19 @@
1
+ import operator
2
+
1
3
  import torch
2
4
  import triton
3
5
  import triton.language as tl
4
6
 
5
- from liger_kernel.ops.utils import calculate_settings, ensure_contiguous
7
+ from liger_kernel.ops.utils import (
8
+ calculate_settings,
9
+ compare_version,
10
+ ensure_contiguous,
11
+ )
12
+
13
+ if compare_version("triton", operator.ge, "3.0.0"):
14
+ from triton.language.extra.libdevice import tanh
15
+ else:
16
+ from triton.language.math import tanh
6
17
 
7
18
 
8
19
  @triton.jit
@@ -26,7 +37,7 @@ def _geglu_tanh_forward_kernel(
26
37
  sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)
27
38
  a_cubed = a_row * a_row * a_row
28
39
  tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
29
- tanh_result = tl.math.tanh(tanh_arg)
40
+ tanh_result = tanh(tanh_arg)
30
41
  geglu_a = 0.5 * a_row * (1 + tanh_result)
31
42
  c_row = geglu_a * b_row
32
43
  tl.store(c + col_offsets, c_row, mask=mask)
@@ -54,7 +65,7 @@ def _geglu_tanh_backward_kernel(
54
65
  sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)
55
66
  a_cubed = a_row * a_row * a_row
56
67
  tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
57
- tanh_result = tl.math.tanh(tanh_arg)
68
+ tanh_result = tanh(tanh_arg)
58
69
  geglu_a = 0.5 * a_row * (1 + tanh_result)
59
70
 
60
71
  db_row = dc_row * geglu_a
@@ -107,8 +107,8 @@ class LigerRMSNormFunction(torch.autograd.Function):
107
107
  n_rows, n_cols = X.shape
108
108
  BLOCK_SIZE, num_warps = calculate_settings(n_cols)
109
109
 
110
- Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device="cuda")
111
- r = torch.empty(n_rows, dtype=X.dtype, device="cuda")
110
+ Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
111
+ r = torch.empty(n_rows, dtype=X.dtype, device=X.device)
112
112
 
113
113
  # Check constraints.
114
114
  assert (
@@ -1,7 +1,10 @@
1
1
  import functools
2
+ import importlib
3
+ from typing import Callable
2
4
 
3
5
  import torch
4
6
  import triton
7
+ from packaging.version import Version
5
8
 
6
9
 
7
10
  def ensure_contiguous(fn):
@@ -36,3 +39,12 @@ def calculate_settings(n):
36
39
  elif BLOCK_SIZE >= 2048:
37
40
  num_warps = 8
38
41
  return BLOCK_SIZE, num_warps
42
+
43
+
44
def compare_version(package: str, operator: Callable, target: str):
    """Compare the installed version of an importable package against ``target``.

    Args:
        package: Name passed to ``importlib.import_module`` (e.g. ``"triton"``).
        operator: Binary comparison callable (e.g. ``operator.ge``). It is called
            as ``operator(installed_version, Version(target))``.
        target: Version string to compare against (e.g. ``"3.0.0"``).

    Returns:
        The result of ``operator``. Returns ``False`` when the package cannot be
        imported or does not expose a ``__version__`` attribute.
    """
    try:
        pkg = importlib.import_module(package)
    except ImportError:
        return False
    # Not every module defines __version__; treat that the same as "cannot
    # determine / not installed" rather than raising AttributeError at the
    # caller's import time.
    raw_version = getattr(pkg, "__version__", None)
    if raw_version is None:
        return False
    pkg_version = Version(raw_version)
    return operator(pkg_version, Version(target))
@@ -37,6 +37,9 @@ def lce_forward(
37
37
  cache_position: Optional[torch.LongTensor] = None,
38
38
  ) -> Union[Tuple, CausalLMOutputWithPast]:
39
39
  r"""
40
+ Copy of the Llama forward pass that replaces torch cross entropy with Liger fused linear cross entropy.
41
+
42
+
40
43
  Args:
41
44
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
42
45
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -5,23 +5,21 @@ from liger_kernel.transformers.rope import liger_rotary_pos_emb
5
5
  from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP, LigerSwiGLUMLP
6
6
 
7
7
 
8
- # TODO: probably rename utils.py as hf_patcher.py to be more descriptive
9
8
  def apply_liger_kernel_to_llama(
10
9
  rope: bool = True,
11
- cross_entropy: bool = True,
12
- fused_linear_cross_entropy: bool = False,
10
+ cross_entropy: bool = False,
11
+ fused_linear_cross_entropy: bool = True,
13
12
  rms_norm: bool = True,
14
13
  swiglu: bool = True,
15
14
  ) -> None:
16
15
  """
17
16
  Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
18
- to make GPU go burrr.
19
17
 
20
18
  Args:
21
19
  rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
22
- cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
20
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
23
21
  fused_linear_cross_entropy (bool):
24
- Whether to apply Liger's fused lienar cross entropy loss. Default is False.
22
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
25
23
  `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
26
24
  If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
27
25
  rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
@@ -54,7 +52,6 @@ def apply_liger_kernel_to_mistral(
54
52
  ) -> None:
55
53
  """
56
54
  Apply Liger kernels to replace original implementation in HuggingFace Mistral models
57
- to make GPU go burrr.
58
55
 
59
56
  Args:
60
57
  rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
@@ -83,12 +80,12 @@ def apply_liger_kernel_to_mixtral(
83
80
  ) -> None:
84
81
  """
85
82
  Apply Liger kernels to replace original implementation in HuggingFace Mixtral models
86
- to make GPU go burrr.
87
83
 
88
84
  Args:
89
85
  rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
90
86
  cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
91
87
  rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
88
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
92
89
  """
93
90
 
94
91
  from transformers.models.mixtral import modeling_mixtral
@@ -1,4 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger-kernel
3
- Version: 0.0.0
3
+ Version: 0.0.1
4
4
  Provides-Extra: dev
5
+ License-File: LICENSE
6
+ License-File: NOTICE
@@ -1,3 +1,6 @@
1
+ LICENSE
2
+ NOTICE
3
+ README.md
1
4
  setup.py
2
5
  src/liger_kernel.egg-info/PKG-INFO
3
6
  src/liger_kernel.egg-info/SOURCES.txt
File without changes