liger-kernel 0.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel-0.0.0/PKG-INFO +4 -0
- liger_kernel-0.0.0/setup.cfg +4 -0
- liger_kernel-0.0.0/setup.py +26 -0
- liger_kernel-0.0.0/src/liger_kernel/ops/__init__.py +0 -0
- liger_kernel-0.0.0/src/liger_kernel/ops/cross_entropy.py +277 -0
- liger_kernel-0.0.0/src/liger_kernel/ops/fused_linear_cross_entropy.py +161 -0
- liger_kernel-0.0.0/src/liger_kernel/ops/geglu.py +129 -0
- liger_kernel-0.0.0/src/liger_kernel/ops/rms_norm.py +167 -0
- liger_kernel-0.0.0/src/liger_kernel/ops/rope.py +234 -0
- liger_kernel-0.0.0/src/liger_kernel/ops/swiglu.py +113 -0
- liger_kernel-0.0.0/src/liger_kernel/ops/utils.py +38 -0
- liger_kernel-0.0.0/src/liger_kernel/transformers/__init__.py +5 -0
- liger_kernel-0.0.0/src/liger_kernel/transformers/cross_entropy.py +11 -0
- liger_kernel-0.0.0/src/liger_kernel/transformers/fused_linear_cross_entropy.py +15 -0
- liger_kernel-0.0.0/src/liger_kernel/transformers/geglu.py +23 -0
- liger_kernel-0.0.0/src/liger_kernel/transformers/model/__init__.py +0 -0
- liger_kernel-0.0.0/src/liger_kernel/transformers/model/llama.py +143 -0
- liger_kernel-0.0.0/src/liger_kernel/transformers/monkey_patch.py +103 -0
- liger_kernel-0.0.0/src/liger_kernel/transformers/rms_norm.py +16 -0
- liger_kernel-0.0.0/src/liger_kernel/transformers/rope.py +20 -0
- liger_kernel-0.0.0/src/liger_kernel/transformers/swiglu.py +40 -0
- liger_kernel-0.0.0/src/liger_kernel/triton/__init__.py +3 -0
- liger_kernel-0.0.0/src/liger_kernel/triton/monkey_patch.py +44 -0
- liger_kernel-0.0.0/src/liger_kernel.egg-info/PKG-INFO +4 -0
- liger_kernel-0.0.0/src/liger_kernel.egg-info/SOURCES.txt +26 -0
- liger_kernel-0.0.0/src/liger_kernel.egg-info/dependency_links.txt +1 -0
- liger_kernel-0.0.0/src/liger_kernel.egg-info/requires.txt +11 -0
- liger_kernel-0.0.0/src/liger_kernel.egg-info/top_level.txt +1 -0
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from setuptools import find_namespace_packages, setup
|
|
2
|
+
|
|
3
|
+
__version__ = "0.0.0"

# Packages required at runtime to use the kernels.
_INSTALL_REQUIRES = [
    "torch>=2.1.2",
    "triton>=2.3.0",
    "transformers>=4.40.1",
]

# Optional development tooling, installed with the "dev" extra.
_DEV_REQUIRES = [
    "matplotlib>=3.7.2",
    "flake8>=4.0.1.1",
    "black>=24.4.2",
    "isort>=5.13.2",
    "pre-commit>=3.7.1",
    "torch-tb-profiler>=0.4.1",
]

# Sources live under src/ (src-layout); every package found there is shipped.
setup(
    name="liger_kernel",
    version=__version__,
    package_dir={"": "src"},
    packages=find_namespace_packages(where="src"),
    include_package_data=True,
    install_requires=_INSTALL_REQUIRES,
    extras_require={"dev": _DEV_REQUIRES},
)
|
|
File without changes
|
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
import triton.language as tl
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@triton.jit
def liger_cross_entropy_kernel(
    X_ptr,
    X_stride,
    Y_ptr,
    Y_stride,
    loss_ptr,
    loss_stride,
    n_cols,
    n_non_ignore,
    ignore_index,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Compute the per-row cross-entropy loss and, in place, the gradient w.r.t. the input logits.

    Each Triton program handles one row. Only hard labels with mean reduction are
    supported; see https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
    for the math.

    On return, the row of X holds d(mean loss)/d(logits) instead of the logits, and
    loss_ptr[row] holds the unreduced loss of that row. Rows whose target equals
    ignore_index are zeroed in X, and their loss slot is not written.

    Parameters:
    X_ptr: Pointer to the input logits; overwritten with the gradient.
    X_stride (int): Stride between consecutive rows of the input tensor.
    Y_ptr: Pointer to the target (class index) tensor.
    Y_stride (int): Stride of the target tensor.
    loss_ptr: Pointer to the tensor that receives the per-row loss.
    loss_stride (int): Stride of the loss tensor.
    n_cols (int): Number of columns (classes) per row of the input.
    n_non_ignore (int): Number of non-ignored targets; gradients are divided by it (mean reduction).
    ignore_index (int): Target value whose rows get no loss and a zero gradient.
    BLOCK_SIZE (int): Number of columns processed per loop iteration.
    """

    # https://github.com/triton-lang/triton/issues/1058
    # If B*T*V is too large, program_id * stride overflows int32, so promote to int64.
    program_id = tl.program_id(0).to(tl.int64)

    # 1. Load the target first: if it is ignore_index we can return right away.
    Y_ptr += program_id * Y_stride
    y = tl.load(Y_ptr)

    # 2. Locate the start of this program's row.
    X_ptr += program_id * X_stride

    if y == ignore_index:
        # Ignored row: its gradient is zero, so overwrite the whole row with 0.
        for i in range(0, n_cols, BLOCK_SIZE):
            X_offsets = i + tl.arange(0, BLOCK_SIZE)
            tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
        return

    loss_ptr += program_id * loss_stride

    # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax).
    # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867

    # 3. [Online softmax] first pass: find the max and the sum of exponentials.
    m = float("-inf")  # m is the running max (notation from the paper)
    d = 0.0  # d is the running sum of exp(x - m) (notation from the paper)
    ori_X_y = tl.load(
        X_ptr + y
    )  # keep the original logit X_y: the second pass overwrites X with gradients

    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(
            X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
        )
        block_max = tl.max(X_block)
        m_new = tl.maximum(m, block_max)
        # Rescale the previous sum to the new max before adding this block's terms.
        d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
        m = m_new

    # 4. [Online softmax] second pass: compute the gradients in place.
    # dx_y = (softmax(x_y) - 1) / N
    # dx_i = softmax(x_i) / N, i != y
    # N is the number of non-ignored elements in the batch.
    # This pass writes softmax(x_i) / N for every column; the -1/N for column y is added in step 6.
    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(
            X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
        )
        X_block = (tl.exp(X_block - m) / d) / (n_non_ignore)
        tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

    # tl.debug_barrier() ensures the values just written to X_ptr are visible before X_y is re-read, see
    # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
    tl.debug_barrier()

    # 5. Calculate the loss.
    # Old approach (problematic log of the softmax):
    # the smallest positive normal bfloat16/float32 value is about 1e-38. If X_y * n_non_ignore
    # underflows to 0, log() returns -inf; even adding a tiny epsilon does not fix this.
    # loss = -tl.log(X_y * n_non_ignore)

    # New approach: safe log-softmax, obtained by reordering the formula.
    # log(softmax(X_y)) = log(e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
    #                   = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
    # sum(e ^ (X - max(X))) is >= 1 because the max term is e ^ 0 = 1, so the log is finite.
    # The loss is the negative of this value.
    loss = -(ori_X_y - m - tl.log(d))

    # 6. Handle the i == y case separately, where dx_y = (softmax(x_y) - 1) / N.
    X_y = tl.load(X_ptr + y)
    X_y += -1 / (n_non_ignore)

    tl.store(loss_ptr, loss)
    tl.store(X_ptr + y, X_y)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
# Triton caps a single tensor at TRITON_MAX_TENSOR_NUMEL = 1048576 elements:
# https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
# A much smaller cap (65536, as in the Triton LayerNorm tutorial) runs faster because it
# causes less register spilling; half of that was chosen here by manual tuning. The best
# value depends on the hardware, the kernel, and the dtype.
MAX_FUSED_SIZE = 1 << 15  # == 65536 // 2
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@triton.jit
def element_mul(
    X_ptr,
    X_stride,
    grad_output_ptr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Scale one row of X, in place, by the scalar stored at grad_output_ptr.

    One Triton program handles one row and walks it in BLOCK_SIZE-wide chunks.

    Parameters:
    X_ptr: Pointer to the tensor to scale (modified in place).
    X_stride (int): Stride between consecutive rows of X.
    grad_output_ptr: Pointer to the single scaling value.
    n_cols (int): Number of columns per row.
    BLOCK_SIZE (int): Number of columns processed per loop iteration.
    """
    # Promote to int64 so row * X_stride cannot overflow int32 on large tensors.
    row = tl.program_id(0).to(tl.int64)
    row_start = X_ptr + row * X_stride

    scale = tl.load(grad_output_ptr)

    for col_start in range(0, n_cols, BLOCK_SIZE):
        cols = col_start + tl.arange(0, BLOCK_SIZE)
        in_bounds = cols < n_cols
        vals = tl.load(row_start + cols, mask=in_bounds)
        tl.store(row_start + cols, vals * scale, mask=in_bounds)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
class LigerCrossEntropyFunction(torch.autograd.Function):
    """
    Custom autograd function for the Liger cross-entropy loss.

    forward() launches one Triton kernel that computes the per-row loss and, to save
    memory, overwrites the logits buffer with d(mean loss)/d(logits). backward() then
    only has to rescale that stored gradient by grad_output.
    """

    @staticmethod
    def forward(ctx, _input, target, ignore_index):
        """
        Forward pass of the Liger cross-entropy loss (hard labels, mean reduction).

        Parameters:
        ctx : The autograd context object.
        _input (tensor): Logits of shape (BT, V), where B is batch size, T is sequence
            length and V is vocab size. Its storage is overwritten with the gradient.
        target (tensor): Class indices of shape (BT,), each in [0, V-1] or equal to ignore_index.
        ignore_index (int): Target value that contributes neither loss nor gradient.

        Returns:
        tensor: The sum of per-row losses divided by the number of non-ignored targets.
        """
        n_rows, n_classes = _input.shape
        block_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_classes))

        # Per-row (unreduced) loss; rows whose target is ignored stay at zero.
        loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)

        n_non_ignore = (target != ignore_index).sum().item()

        # The kernel only needs each row to be contiguous (last-dim stride of 1). Rows
        # may be spaced arbitrarily because every program jumps by the row stride. For
        # example, torch.arange(1, 21).reshape(5, -1)[::2, :] is not contiguous overall,
        # but its strides are (8, 1), so it can be used as-is. Copy only when the last
        # dimension itself is strided.
        if _input.stride(-1) != 1:
            _input = _input.contiguous()
        if target.stride(-1) != 1:
            target = target.contiguous()

        # One program per row; the gradient is written back into _input's storage.
        liger_cross_entropy_kernel[(n_rows,)](
            X_ptr=_input,
            X_stride=_input.stride(-2),
            Y_ptr=target,
            Y_stride=target.stride(-1),  # always 1
            loss_ptr=loss_1d,
            loss_stride=loss_1d.stride(-1),  # always 1
            n_cols=n_classes,
            n_non_ignore=n_non_ignore,
            ignore_index=ignore_index,
            BLOCK_SIZE=block_size,
            # TODO: 32 warps gave the best performance so far;
            # performance is quite sensitive to num_warps.
            num_warps=32,
        )

        loss = torch.sum(loss_1d) / n_non_ignore

        # TODO: investigate. Saving the buffer without detach() was observed to double
        # memory use; the cause has not been identified.
        ctx.save_for_backward(_input.detach())
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass: scale the gradient stored during forward by grad_output.

        Parameters:
        ctx : The autograd context holding the saved gradient buffer.
        grad_output (tensor): Upstream gradient of the scalar loss.

        Returns:
        tuple: (gradient w.r.t. _input, None, None); target and ignore_index get no gradient.
        """
        (grad_logits,) = ctx.saved_tensors

        # When the loss is the last layer, grad_output is exactly 1.0 and the rescale is skipped.
        upstream_is_one = torch.equal(
            grad_output, torch.tensor(1.0, device=grad_output.device)
        )
        if not upstream_is_one:
            # A Triton kernel is used instead of an in-place PyTorch op: modifying inputs
            # in place for gradient storage and running backward multiple times caused
            # anomalies with PyTorch but not with Triton.
            # https://github.com/triton-lang/triton/issues/4004
            n_rows, n_classes = grad_logits.shape
            element_mul[(n_rows,)](
                grad_logits,
                grad_logits.stride(-2),
                grad_output,
                n_classes,
                BLOCK_SIZE=min(MAX_FUSED_SIZE, triton.next_power_of_2(n_classes)),
                num_warps=32,
            )

        return grad_logits, None, None
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
"""Fusing the last linear layer with cross-entropy loss
|
|
2
|
+
|
|
3
|
+
Reference: https://github.com/mgmalek/efficient_cross_entropy
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import triton
|
|
8
|
+
|
|
9
|
+
from liger_kernel.ops.cross_entropy import element_mul, liger_cross_entropy_kernel
|
|
10
|
+
|
|
11
|
+
# Triton caps a single tensor at TRITON_MAX_TENSOR_NUMEL = 1048576 elements:
# https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
# A much smaller cap (65536, as in the Triton LayerNorm tutorial) runs faster because it
# causes less register spilling; half of that was chosen here by manual tuning. The best
# value depends on the hardware, the kernel, and the dtype.
MAX_FUSED_SIZE = 1 << 15  # == 65536 // 2
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
    """
    Fused final linear projection + cross-entropy loss.

    The BT x V logits are materialized only one chunk of rows at a time, and the
    gradients w.r.t. the input and the projection weight are computed during the
    forward pass, so nothing of size BT x V is kept for backward.
    """

    @staticmethod
    def forward(ctx, _input, linear, target, ignore_index):
        """
        Handle the forward and backward pass of the final linear layer via cross-entropy loss,
        avoiding materialization of the full logits tensor. Because cross-entropy is the last
        layer, the gradients can be computed during the forward pass, so _input and target do
        not need to be stored for backward.

        _input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension.
        target: (B*T) where each value is in [0, V-1] or equal to ignore_index.
        linear: linear projection matrix of shape V x H.
        ignore_index: the index to ignore in the target.

        Returns the mean loss over non-ignored targets. If every target is ignored, the
        loss is NaN (0 / 0) and both gradients are zero; PyTorch's mean-reduced
        CrossEntropyLoss behaves the same way.
        """
        dtype = (
            torch.get_autocast_gpu_dtype()
            if torch.is_autocast_enabled()
            else _input.dtype
        )
        device = _input.device

        # inputs have shape: BT x H
        # materialized activations will have shape: BT x V
        # the increase in memory = BT x V
        # this can be reduced by partitioning the BT tokens into smaller chunks.
        # e.g. to match the memory of a BT x H tensor, the chunk size should be:
        # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor
        # e.g. BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
        BT, H = _input.shape
        V = linear.shape[0]
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

        inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
        chunk_size = triton.next_power_of_2(
            triton.cdiv(BT, inc_factor)
        )  # (BT + inc_factor - 1) // inc_factor
        num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size

        grad_linear = torch.zeros_like(linear, device=device)
        grad_input = torch.zeros_like(_input, device=device)
        loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)

        # The only host sync: the batch-wide number of non-ignored targets. Every chunk
        # normalizes its gradient by this global count (mean reduction over the whole
        # batch), so no per-chunk count or rescaling is needed. The previous per-chunk
        # rescale by n_non_ignore / total_n_non_ignore raised ZeroDivisionError when
        # every target was ignored. With a count of 0, every row takes the kernel's
        # ignore_index early-return before any division, so the gradients stay zero.
        total_n_non_ignore = (target != ignore_index).sum().item()

        for chunk_id in range(num_chunks):
            start_idx = chunk_id * chunk_size
            end_idx = min((chunk_id + 1) * chunk_size, BT)
            _input_chunk = _input[start_idx:end_idx]  # chunk_size x H

            # when doing matmul, use the original precision
            logits_chunk = _input_chunk @ linear.t()  # chunk_size x V
            target_chunk = target[start_idx:end_idx]  # chunk_size,

            n_rows = logits_chunk.shape[0]

            # unreduced loss (a view into loss_1d, filled in by the kernel)
            loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,

            # when doing CE, use the upcasted precision
            logits_chunk = logits_chunk.float()

            # ensure logits and target are contiguous
            logits_chunk = logits_chunk.contiguous()
            target_chunk = target_chunk.contiguous()

            # The kernel overwrites logits_chunk in place with d(mean loss)/d(logits),
            # already divided by the batch-wide non-ignored count.
            liger_cross_entropy_kernel[(n_rows,)](
                X_ptr=logits_chunk,
                X_stride=logits_chunk.stride(-2),
                Y_ptr=target_chunk,
                Y_stride=target_chunk.stride(-1),  # always 1
                loss_ptr=loss_1d_slice,
                loss_stride=loss_1d_slice.stride(-1),  # always 1
                n_cols=V,
                n_non_ignore=total_n_non_ignore,
                ignore_index=ignore_index,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=32,
            )

            # Following the HuggingFace model source code, the forward and backward w.r.t.
            # the logits are done in fp32 for numerical stability, since the number of
            # classes (vocab size) is huge.
            # (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194)
            # For the lm_head backward we switch back to the original dtype.
            grad_logits_chunk = logits_chunk.to(dtype)  # chunk_size x V

            # chain rule through logits = _input_chunk @ linear.t()
            grad_input[start_idx:end_idx] = grad_logits_chunk @ linear  # chunk_size x H

            # grad_linear += grad_logits_chunk.T @ _input_chunk  (V x H)
            torch.addmm(
                input=grad_linear,
                mat1=grad_logits_chunk.t(),
                mat2=_input_chunk,
                out=grad_linear,
                alpha=1.0,
                beta=1.0,
            )

        # Tensor / 0 yields NaN (no exception) when every target is ignored.
        loss = torch.sum(loss_1d) / total_n_non_ignore

        # Store the precomputed gradients for backward.
        ctx.save_for_backward(grad_input.detach(), grad_linear.detach())
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        """
        Scale the gradients precomputed in forward by grad_output.

        Returns (grad_input, grad_linear, None, None); target and ignore_index get no gradient.
        """
        (grad_input, grad_linear) = ctx.saved_tensors
        # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time.
        if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
            # A Triton kernel is used instead of an in-place PyTorch op: modifying inputs
            # in place for gradient storage and running backward multiple times caused
            # anomalies with PyTorch but not with Triton.
            BT, H = grad_input.shape
            n_rows = BT
            # grad_input and grad_linear both have H columns, so one BLOCK_SIZE serves both.
            BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))

            element_mul[(n_rows,)](
                grad_input,
                grad_input.stride(-2),
                grad_output,
                H,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=32,
            )

            # handle grad_linear
            V, H = grad_linear.shape
            n_rows = V

            element_mul[(n_rows,)](
                grad_linear,
                grad_linear.stride(-2),
                grad_output,
                H,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=32,
            )

        return (grad_input, grad_linear, None, None)
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
import triton.language as tl
|
|
4
|
+
|
|
5
|
+
from liger_kernel.ops.utils import calculate_settings, ensure_contiguous
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@triton.jit
def _geglu_tanh_forward_kernel(
    a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
    """
    Compute c = gelu_tanh(a) * b, one row per Triton program.

    a, b, c: pointers to buffers that share the same row stride.
    stride: distance in elements between consecutive rows.
    n_cols: number of valid columns per row.
    BLOCK_SIZE: columns loaded per program. There is no column loop, so it must be
        >= n_cols for the whole row to be processed.
    """
    # Cast to int64 before multiplying by the stride: with the default int32 program
    # id, program_id * stride overflows once n_rows * stride exceeds 2**31 - 1.
    program_id = tl.program_id(0).to(tl.int64)

    # locate start index
    a += program_id * stride
    b += program_id * stride
    c += program_id * stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    # a is upcast to float32 for the GELU math; b keeps its storage dtype.
    a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
    b_row = tl.load(b + col_offsets, mask=mask, other=0)

    # tanh approximation form of GELU is computed with:
    # 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))
    # NOTE(review): tl.math.tanh exists in Triton 2.x; newer Triton releases moved tanh
    # to libdevice. Confirm against the installed Triton version.
    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
    a_cubed = a_row * a_row * a_row
    tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
    tanh_result = tl.math.tanh(tanh_arg)
    geglu_a = 0.5 * a_row * (1 + tanh_result)
    c_row = geglu_a * b_row
    tl.store(c + col_offsets, c_row, mask=mask)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@triton.jit
def _geglu_tanh_backward_kernel(
    dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
    """
    Backward of c = gelu_tanh(a) * b, one row per Triton program.

    The gradients are written in place: da overwrites a and db overwrites b.
    dc, a, b: pointers to buffers that share the same row stride.
    stride: distance in elements between consecutive rows.
    n_cols: number of valid columns per row.
    BLOCK_SIZE: columns loaded per program. There is no column loop, so it must be
        >= n_cols for the whole row to be processed.
    """
    # Cast to int64 before multiplying by the stride: with the default int32 program
    # id, program_id * stride overflows once n_rows * stride exceeds 2**31 - 1.
    program_id = tl.program_id(0).to(tl.int64)

    # locate start index
    dc += program_id * stride
    a += program_id * stride
    b += program_id * stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    dc_row = tl.load(dc + col_offsets, mask=mask, other=0)
    a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
    b_row = tl.load(b + col_offsets, mask=mask, other=0)

    # Recompute the forward activation instead of storing it, to save memory.
    # NOTE(review): tl.math.tanh exists in Triton 2.x; newer Triton releases moved tanh
    # to libdevice. Confirm against the installed Triton version.
    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
    a_cubed = a_row * a_row * a_row
    tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
    tanh_result = tl.math.tanh(tanh_arg)
    geglu_a = 0.5 * a_row * (1 + tanh_result)

    db_row = dc_row * geglu_a

    # Gradient w.r.t. a can be computed with:
    # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
    # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
    term1 = 0.5 * (1 + tanh_result)
    tanh_sq = tanh_result * tanh_result
    term2 = (
        0.5
        * a_row
        * (1 - tanh_sq)
        * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
    )
    da_row = dc_row * b_row * (term1 + term2)

    tl.store(a + col_offsets, da_row, mask=mask)
    tl.store(b + col_offsets, db_row, mask=mask)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class LigerGELUMulFunction(torch.autograd.Function):
    """
    Autograd function for out = gelu_tanh(a) * b (the GEGLU gate), backed by Triton kernels.

    Inputs of any leading shape are flattened to 2-D (rows x last dim) for the kernels,
    and results are reshaped back to the original shape.
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, a, b):
        """Return gelu_tanh(a) * b with the shape of a; a and b are saved for backward."""
        input_shape = a.shape
        hidden = input_shape[-1]

        a_2d = a.view(-1, hidden)
        b_2d = b.view(-1, hidden)
        out = torch.zeros_like(a_2d)

        block_size, warps = calculate_settings(hidden)

        # One program per row.
        _geglu_tanh_forward_kernel[(a_2d.shape[0],)](
            a_2d,
            b_2d,
            out,
            out.stride(-2),
            n_cols=hidden,
            BLOCK_SIZE=block_size,
            num_warps=warps,
        )

        ctx.save_for_backward(a_2d, b_2d)
        return out.view(*input_shape)

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dc):
        """Return (grad_a, grad_b); the kernel writes them into the saved a and b buffers."""
        grad_shape = dc.shape
        hidden = grad_shape[-1]
        dc_2d = dc.view(-1, hidden)

        a, b = ctx.saved_tensors

        block_size, warps = calculate_settings(hidden)

        _geglu_tanh_backward_kernel[(dc_2d.shape[0],)](
            dc_2d,
            a,
            b,
            dc_2d.stride(-2),
            n_cols=hidden,
            BLOCK_SIZE=block_size,
            num_warps=warps,
        )

        return a.view(*grad_shape), b.view(*grad_shape)
|