liger-kernel 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/env_report.py +46 -0
- liger_kernel/ops/cross_entropy.py +130 -63
- liger_kernel/ops/experimental/embedding.py +143 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +203 -126
- liger_kernel/ops/geglu.py +54 -42
- liger_kernel/ops/kl_div.py +247 -0
- liger_kernel/ops/layer_norm.py +236 -0
- liger_kernel/ops/rms_norm.py +220 -84
- liger_kernel/ops/rope.py +91 -84
- liger_kernel/ops/swiglu.py +48 -41
- liger_kernel/ops/utils.py +12 -0
- liger_kernel/transformers/__init__.py +22 -0
- liger_kernel/transformers/auto_model.py +33 -0
- liger_kernel/transformers/cross_entropy.py +11 -1
- liger_kernel/transformers/experimental/embedding.py +28 -0
- liger_kernel/transformers/functional.py +19 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +8 -2
- liger_kernel/transformers/geglu.py +4 -2
- liger_kernel/transformers/kl_div.py +13 -0
- liger_kernel/transformers/layer_norm.py +30 -0
- liger_kernel/transformers/model/gemma.py +138 -0
- liger_kernel/transformers/model/llama.py +1 -1
- liger_kernel/transformers/model/mistral.py +138 -0
- liger_kernel/transformers/model/mixtral.py +158 -0
- liger_kernel/transformers/model/phi3.py +136 -0
- liger_kernel/transformers/model/qwen2.py +135 -0
- liger_kernel/transformers/model/qwen2_vl.py +172 -0
- liger_kernel/transformers/monkey_patch.py +605 -14
- liger_kernel/transformers/rms_norm.py +23 -4
- liger_kernel/transformers/swiglu.py +24 -0
- liger_kernel/transformers/trainer_integration.py +2 -45
- liger_kernel-0.3.0.dist-info/METADATA +388 -0
- liger_kernel-0.3.0.dist-info/RECORD +42 -0
- {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/WHEEL +1 -1
- liger_kernel-0.1.0.dist-info/METADATA +0 -16
- liger_kernel-0.1.0.dist-info/RECORD +0 -27
- {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/LICENSE +0 -0
- {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/NOTICE +0 -0
- {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/top_level.txt +0 -0
liger_kernel/env_report.py (new file)

@@ -0,0 +1,46 @@
+import platform
+import sys
+
+
+def print_env_report():
+    """
+    Prints a report of the environment. Useful for debugging and reproducibility.
+    Usage:
+    ```
+    python -m liger_kernel.env_report
+    ```
+    """
+    print("Environment Report:")
+    print("-------------------")
+    print(f"Operating System: {platform.platform()}")
+    print(f"Python version: {sys.version.split()[0]}")
+
+    try:
+        import torch
+
+        print(f"PyTorch version: {torch.__version__}")
+        cuda_version = (
+            torch.version.cuda if torch.cuda.is_available() else "Not available"
+        )
+        print(f"CUDA version: {cuda_version}")
+    except ImportError:
+        print("PyTorch: Not installed")
+        print("CUDA version: Unable to query")
+
+    try:
+        import triton
+
+        print(f"Triton version: {triton.__version__}")
+    except ImportError:
+        print("Triton: Not installed")
+
+    try:
+        import transformers
+
+        print(f"Transformers version: {transformers.__version__}")
+    except ImportError:
+        print("Transformers: Not installed")
+
+
+if __name__ == "__main__":
+    print_env_report()
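Not part of the diff: for quick reference, the report added above can be generated either with `python -m liger_kernel.env_report` (as the docstring shows) or by importing the function directly. A minimal sketch:

```python
# Minimal usage sketch for the new env_report module.
from liger_kernel.env_report import print_env_report

print_env_report()  # prints OS, Python, PyTorch, CUDA, Triton, and Transformers versions
```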
liger_kernel/ops/cross_entropy.py

@@ -14,6 +14,8 @@ def liger_cross_entropy_kernel(
     n_cols,
     n_non_ignore,
     ignore_index,
+    label_smoothing: tl.constexpr,
+    reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
     BLOCK_SIZE: tl.constexpr,
 ):
     """
@@ -30,6 +32,8 @@ def liger_cross_entropy_kernel(
         n_cols (int): The number of columns in the input tensor.
         n_non_ignore (int): The number of non-ignored elements in the batch.
         ignore_index (int): The index to ignore in the target.
+        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
+        reduction (str): The string for the reduction to apply
         BLOCK_SIZE (int): The block size for Triton operations.
     """

@@ -56,37 +60,62 @@ def liger_cross_entropy_kernel(
     # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
     # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867

-    # 3. [
+    # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
     d = 0.0  # d is the sum. use the notation from the paper
     ori_X_y = tl.load(
         X_ptr + y
     )  # we need to store the original value of X_y for the loss calculation

+    # Label smoothing is a general case of normal cross entropy
+    # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
+    scaled_x_sum = 0.0
+    eps = label_smoothing / n_cols
+
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
             X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
         )
         block_max = tl.max(X_block)
+        if label_smoothing > 0:
+            # scale X beforehand to avoid overflow
+            scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
         m_new = tl.maximum(m, block_max)
         d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
         m = m_new

-    # 4. [
+    # 4. [Online Softmax] Second pass: compute gradients
+    # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
     # dx_y = (softmax(x_y) - 1) / N
     # dx_i = softmax(x_i) / N, i != y
-    #
+    # For label smoothing:
+    # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y
+    # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
+    #      = dx_i - (1 - label_smoothing) / N
+    #
+    # For 'sum' reduction, no normalization is applied:
+    # dx_y = softmax(x_y) - 1
+    # dx_i = softmax(x_i), for i ≠ y
+    # For label smoothing:
+    # dx_i = (softmax(x_y) - label_smoothing / V), V = n_cols, i != y
+    # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing))
+    #      = dx_i - (1 - label_smoothing)
+
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
             X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
         )
-
+        if reduction == "mean":
+            X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
+        else:
+            X_block = tl.exp(X_block - m) / d - eps
+
         tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

     # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
-    #
+    # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
     tl.debug_barrier()

     # 5. Calculate the loss
@@ -97,9 +126,28 @@ def liger_cross_entropy_kernel(
     # So we can safely calculate log (softmax(X_y)) without overflow
     loss = -(ori_X_y - m - tl.log(d))

-    #
+    # Orginal loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
+    # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
+    # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
+    # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
+    # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
+    # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
+    # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
+    # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
+    if label_smoothing > 0:
+        smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
+        loss = loss * (1 - label_smoothing) + smooth_loss
+
+    # Normalize the loss by the number of non-ignored elements if reduction is "mean"
+    if reduction == "mean":
+        loss = loss / n_non_ignore
+
+    # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N`
     X_y = tl.load(X_ptr + y)
-
+    if reduction == "mean":
+        X_y += -(1 - label_smoothing) / (n_non_ignore)
+    else:
+        X_y += -(1 - label_smoothing)

     tl.store(loss_ptr, loss)
     tl.store(X_ptr + y, X_y)
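The comment block above leans on the identity H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p), where u is the uniform distribution over the V classes. A small PyTorch check of that identity, illustrative only and not part of the package, using `torch.nn.functional.cross_entropy`'s own `label_smoothing` argument as the reference:

```python
# Numerically verify the label-smoothing decomposition cited in the kernel comments.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
V, ls = 8, 0.1
logits = torch.randn(4, V)
target = torch.randint(0, V, (4,))

ce = F.cross_entropy(logits, target)                          # H(q, p), averaged over the batch
h_u_p = (-F.log_softmax(logits, dim=-1)).mean(dim=-1).mean()  # H(u, p), averaged over the batch
smoothed = (1 - ls) * ce + ls * h_u_p

reference = F.cross_entropy(logits, target, label_smoothing=ls)
assert torch.allclose(smoothed, reference, atol=1e-6)
```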
@@ -112,7 +160,7 @@ MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manually tuning


 @triton.jit
-def
+def element_mul_kernel(
     X_ptr,
     X_stride,
     grad_output_ptr,
@@ -147,6 +195,70 @@ def element_mul(
     tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)


+def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction):
+    BT, V = _input.shape
+    n_rows = BT
+
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+
+    # unreduced loss
+    loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
+
+    n_non_ignore = (target != ignore_index).sum().item()
+
+    # ensure _input and target are contiguous in the last dimension
+    if _input.stride(-1) != 1:
+        _input = _input.contiguous()
+    if target.stride(-1) != 1:
+        target = target.contiguous()
+
+    # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
+    liger_cross_entropy_kernel[(n_rows,)](
+        X_ptr=_input,
+        X_stride=_input.stride(-2),
+        Y_ptr=target,
+        Y_stride=target.stride(-1),  # always 1
+        loss_ptr=loss_1d,
+        loss_stride=loss_1d.stride(-1),  # always 1
+        n_cols=V,
+        n_non_ignore=n_non_ignore,
+        ignore_index=ignore_index,
+        label_smoothing=label_smoothing,
+        reduction=reduction,
+        BLOCK_SIZE=BLOCK_SIZE,
+        # TODO: 32 seems to give the best performance
+        # Performance is quite sensitive to num_warps
+        num_warps=32,
+    )
+
+    loss = torch.sum(loss_1d)
+    return loss, _input
+
+
+def cross_entropy_backward(_input, grad_output):
+    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
+    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+        pass
+
+    # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
+    # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
+    else:
+        BT, V = _input.shape
+        n_rows = BT
+        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+
+        element_mul_kernel[(n_rows,)](
+            _input,
+            _input.stride(-2),
+            grad_output,
+            V,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=32,
+        )
+
+    return _input
+
+
 class LigerCrossEntropyFunction(torch.autograd.Function):
     """
     This class implements a custom autograd function for the Liger Cross Entropy loss.
@@ -154,7 +266,9 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
     """

     @staticmethod
-    def forward(
+    def forward(
+        ctx, _input, target, ignore_index=-100, label_smoothing=0.0, reduction="mean"
+    ):
         """
         The forward pass of the Liger Cross Entropy loss.

@@ -163,45 +277,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
         target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
         ignore_index (int): The index to ignore in the target.
+        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
+        reduction (str): The reduction to apply to the output: "none" | "mean | "sum".

         Returns:
         tensor: The computed loss.
         """
-
-
-
-        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-
-        # unreduced loss
-        loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
-
-        n_non_ignore = (target != ignore_index).sum().item()
-
-        # ensure _input and target are contiguous in the last dimension
-        if _input.stride(-1) != 1:
-            _input = _input.contiguous()
-        if target.stride(-1) != 1:
-            target = target.contiguous()
-
-        # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
-        liger_cross_entropy_kernel[(n_rows,)](
-            X_ptr=_input,
-            X_stride=_input.stride(-2),
-            Y_ptr=target,
-            Y_stride=target.stride(-1),  # always 1
-            loss_ptr=loss_1d,
-            loss_stride=loss_1d.stride(-1),  # always 1
-            n_cols=V,
-            n_non_ignore=n_non_ignore,
-            ignore_index=ignore_index,
-            BLOCK_SIZE=BLOCK_SIZE,
-            # TODO: 32 seems to give the best performance
-            # Performance is quite sentitive to num_warps
-            num_warps=32,
+        loss, _input = cross_entropy_forward(
+            _input, target, ignore_index, label_smoothing, reduction
         )
-
-        loss = torch.sum(loss_1d) / n_non_ignore
-
         # TODO: investigation
         # If we don't detach the _input tensor, the memory will double
         # Not sure why but seems that there will be a time both grad and value exist but in different location
@@ -221,28 +305,11 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
         """
         (_input,) = ctx.saved_tensors
-
-        if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
-            pass
-
-        # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
-        # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
-        else:
-            BT, V = _input.shape
-            n_rows = BT
-            BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-
-            element_mul[(n_rows,)](
-                _input,
-                _input.stride(-2),
-                grad_output,
-                V,
-                BLOCK_SIZE=BLOCK_SIZE,
-                num_warps=32,
-            )
-
+        _input = cross_entropy_backward(_input, grad_output)
         return (
             _input,
             None,
             None,
+            None,
+            None,
         )
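With the refactor above, the autograd entry point now accepts `label_smoothing` and `reduction` alongside `ignore_index`. A usage sketch (illustrative only; a CUDA device with Triton is assumed, and as the code comments note, gradients are written in place into the input's storage):

```python
import torch
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

logits = torch.randn(4, 32000, device="cuda", requires_grad=True)  # (BT, V)
target = torch.randint(0, 32000, (4,), device="cuda")              # (BT,)

# Positional args follow forward(): _input, target, ignore_index, label_smoothing, reduction
loss = LigerCrossEntropyFunction.apply(logits, target, -100, 0.1, "mean")
loss.backward()
```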
liger_kernel/ops/experimental/embedding.py (new file)

@@ -0,0 +1,143 @@
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import ensure_contiguous
+
+
+@triton.jit
+def embedding_forward_kernel(
+    embeddings_ptr,
+    indices_ptr,
+    output_ptr,
+    n_elements,
+    embedding_dim: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+):
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+
+    start_m = pid_m * BLOCK_SIZE_M
+    start_n = pid_n * BLOCK_SIZE_N
+    offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
+    mask_m = offsets_m < n_elements
+    indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)
+    offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
+    mask_n = offsets_n < embedding_dim
+
+    embedding_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]
+    embeddings = tl.load(
+        embeddings_ptr + embedding_offsets,
+        mask=mask_m[:, None] & mask_n[None, :],
+        other=0.0,
+    )
+
+    output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :]
+    tl.store(
+        output_ptr + output_offsets, embeddings, mask=mask_m[:, None] & mask_n[None, :]
+    )
+
+
+@triton.jit
+def embedding_backward_kernel(
+    grad_output_ptr,
+    grad_weight_ptr,
+    indices_ptr,
+    n_elements,
+    embedding_dim: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+):
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+
+    start_m = pid_m * BLOCK_SIZE_M
+    start_n = pid_n * BLOCK_SIZE_N
+    offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
+    mask_m = offsets_m < n_elements
+    indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)
+    offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
+    mask_n = offsets_n < embedding_dim
+
+    grad_output = tl.load(
+        grad_output_ptr + offsets_m[:, None] * embedding_dim + offsets_n[None, :],
+        mask=mask_m[:, None] & mask_n[None, :],
+        other=0.0,
+    )
+
+    grad_weight_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]
+
+    tl.atomic_add(
+        grad_weight_ptr + grad_weight_offsets,
+        grad_output,
+        mask=mask_m[:, None] & mask_n[None, :],
+    )
+
+
+class LigerEmbeddingFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, embeddings: torch.Tensor, indices: torch.Tensor):
+        ori_shape = indices.shape
+        indices = indices.view(-1)
+        output = torch.empty(
+            indices.shape[0],
+            embeddings.shape[1],
+            device=indices.device,
+            dtype=embeddings.dtype,
+        )
+
+        n_elements = indices.numel()
+        embedding_dim = embeddings.shape[1]
+
+        BLOCK_SIZE_M = triton.next_power_of_2(min(128, embedding_dim))
+        BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim))
+        grid = (
+            triton.cdiv(n_elements, BLOCK_SIZE_M),
+            triton.cdiv(embedding_dim, BLOCK_SIZE_N),
+        )
+
+        embedding_forward_kernel[grid](
+            embeddings,
+            indices,
+            output,
+            n_elements,
+            embedding_dim=embedding_dim,
+            BLOCK_SIZE_M=BLOCK_SIZE_M,
+            BLOCK_SIZE_N=BLOCK_SIZE_N,
+        )
+
+        ctx.save_for_backward(indices, embeddings)
+
+        return output.view(*ori_shape, -1)
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_output: torch.Tensor):
+        indices, embedding_table = ctx.saved_tensors
+        grad_output = grad_output.contiguous().view(-1, embedding_table.shape[1])
+
+        grad_weight = torch.zeros_like(embedding_table)
+
+        n_elements = indices.numel()
+        embedding_dim = embedding_table.shape[1]
+
+        BLOCK_SIZE_M = triton.next_power_of_2(min(128, embedding_dim))
+        BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim))
+        grid = (
+            triton.cdiv(n_elements, BLOCK_SIZE_M),
+            triton.cdiv(embedding_dim, BLOCK_SIZE_N),
+        )
+
+        embedding_backward_kernel[grid](
+            grad_output,
+            grad_weight,
+            indices,
+            n_elements,
+            embedding_dim=embedding_dim,
+            BLOCK_SIZE_M=BLOCK_SIZE_M,
+            BLOCK_SIZE_N=BLOCK_SIZE_N,
+        )
+
+        return grad_weight, None
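A quick sanity check of the new experimental embedding op against the stock PyTorch lookup (illustrative only; the shapes are arbitrary and a CUDA device with Triton is assumed):

```python
import torch
from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction

weight = torch.randn(1000, 256, device="cuda", requires_grad=True)  # (num_embeddings, embedding_dim)
indices = torch.randint(0, 1000, (8, 16), device="cuda")

out = LigerEmbeddingFunction.apply(weight, indices)   # shape (8, 16, 256), per forward()'s view(*ori_shape, -1)
ref = torch.nn.functional.embedding(indices, weight)
assert torch.allclose(out, ref)
```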