liger-kernel 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +2 -0
- liger_kernel/chunked_loss/cpo_loss.py +18 -8
- liger_kernel/chunked_loss/dpo_loss.py +20 -10
- liger_kernel/chunked_loss/functional.py +4 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +58 -44
- liger_kernel/chunked_loss/fused_linear_preference.py +108 -60
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
- liger_kernel/chunked_loss/jsd_loss.py +154 -0
- liger_kernel/chunked_loss/kto_loss.py +172 -0
- liger_kernel/chunked_loss/orpo_loss.py +8 -9
- liger_kernel/chunked_loss/simpo_loss.py +22 -8
- liger_kernel/env_report.py +5 -12
- liger_kernel/ops/cross_entropy.py +102 -51
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +89 -55
- liger_kernel/ops/fused_linear_jsd.py +11 -29
- liger_kernel/ops/geglu.py +6 -17
- liger_kernel/ops/group_norm.py +11 -28
- liger_kernel/ops/jsd.py +2 -6
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +3 -5
- liger_kernel/ops/qwen2vl_mrope.py +21 -37
- liger_kernel/ops/rms_norm.py +14 -32
- liger_kernel/ops/rope.py +31 -33
- liger_kernel/ops/swiglu.py +4 -8
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +16 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +4 -6
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/functional.py +11 -7
- liger_kernel/transformers/fused_linear_cross_entropy.py +12 -7
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +3 -9
- liger_kernel/transformers/jsd.py +1 -3
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/model/gemma.py +18 -40
- liger_kernel/transformers/model/gemma2.py +19 -41
- liger_kernel/transformers/model/llama.py +22 -48
- liger_kernel/transformers/model/mistral.py +14 -26
- liger_kernel/transformers/model/mixtral.py +24 -54
- liger_kernel/transformers/model/mllama.py +16 -36
- liger_kernel/transformers/model/phi3.py +18 -40
- liger_kernel/transformers/model/qwen2.py +18 -40
- liger_kernel/transformers/model/qwen2_vl.py +36 -32
- liger_kernel/transformers/monkey_patch.py +43 -117
- liger_kernel/transformers/qwen2vl_mrope.py +2 -2
- liger_kernel/transformers/rms_norm.py +4 -4
- liger_kernel/transformers/rope.py +2 -2
- liger_kernel/transformers/swiglu.py +2 -8
- liger_kernel/transformers/trainer/__init__.py +1 -3
- liger_kernel/transformers/trainer/orpo_trainer.py +31 -18
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/METADATA +38 -25
- liger_kernel-0.5.3.dist-info/RECORD +69 -0
- {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/WHEEL +1 -1
- liger_kernel-0.5.1.dist-info/RECORD +0 -65
- {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/LICENSE +0 -0
- {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/NOTICE +0 -0
- {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/top_level.txt +0 -0
--- a/liger_kernel/ops/cross_entropy.py
+++ b/liger_kernel/ops/cross_entropy.py
@@ -1,11 +1,14 @@
 import operator
+
 from typing import Optional

 import torch
 import triton
 import triton.language as tl

-from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import element_mul_kernel
+from liger_kernel.ops.utils import is_hip

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -17,9 +20,6 @@ if compare_version("triton", operator.ge, "3.0.0"):
 else:
     from triton.language.math import tanh

-_TRUE = tl.constexpr(1)
-_FALSE = tl.constexpr(0)
-

 @triton.jit
 def liger_cross_entropy_kernel(
@@ -27,11 +27,14 @@ def liger_cross_entropy_kernel(
     X_stride,
     Y_ptr,
     Y_stride,
+    weight_ptr,
     loss_ptr,
     z_loss_ptr,
     loss_stride,
     n_cols,
     n_non_ignore,
+    sum_non_ignore_weight,
+    weight_sum,
     ignore_index,
     lse_square_scale: tl.constexpr,
     label_smoothing: tl.constexpr,
@@ -39,6 +42,7 @@ def liger_cross_entropy_kernel(
     softcap,
     RETURN_Z_LOSS: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
 ):
     """
@@ -50,18 +54,22 @@
         X_stride (int): The stride of the input tensor.
         Y_ptr: Pointer to target tensor.
         Y_stride (int): The stride of the target tensor.
+        weight_ptr: Pointer to weight tensor.
         loss_ptr: Pointer to tensor to store the loss.
         z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
         loss_stride (int): The stride of the loss tensor.
         n_cols (int): The number of columns in the input tensor.
-        n_non_ignore (
+        n_non_ignore (flaot): The number of non-ignored elements in the batch.
+        sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
+        weight_sum (float): The sum of weight tensor.
         ignore_index (int): The index to ignore in the target.
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
-        RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
         reduction (str): The string for the reduction to apply
         softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
+        RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
         BLOCK_SIZE (int): The block size for Triton operations.
+        HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
         HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
     """

@@ -84,7 +92,11 @@ def liger_cross_entropy_kernel(
         return

     loss_ptr += program_id * loss_stride
-
+    if RETURN_Z_LOSS:
+        z_loss_ptr += program_id * loss_stride
+
+    if HAS_WEIGHT:
+        weight_y = tl.load(weight_ptr + y).cast(tl.float32)

     # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
     # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
@@ -92,9 +104,7 @@ def liger_cross_entropy_kernel(
     # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
     d = 0.0  # d is the sum. use the notation from the paper
-    ori_X_y = tl.load(X_ptr + y).cast(
-        tl.float32
-    )  # we need to store the original value of X_y for the loss calculation
+    ori_X_y = tl.load(X_ptr + y).cast(tl.float32)  # we need to store the original value of X_y for the loss calculation
     if HAS_SOFTCAPPING:
         ori_X_y = softcap * tanh(ori_X_y / softcap)

@@ -116,7 +126,11 @@ def liger_cross_entropy_kernel(
         block_max = tl.max(X_block)
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
-
+            if HAS_WEIGHT:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0))
+            else:
+                scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
         m_new = tl.maximum(m, block_max)
         d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
         m = m_new
@@ -152,18 +166,41 @@ def liger_cross_entropy_kernel(
         if HAS_SOFTCAPPING:
             intermediate = tanh(X_block / softcap)
             X_block = softcap * intermediate
-
-
-
-
-
-
-
-
-
-
-
-
+
+        if not HAS_WEIGHT:
+            # softmax(x_i)
+            X_block = tl.exp(X_block - m) / d
+            # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+            X_block += 2 * lse_square_scale * lse * X_block
+            # smoothing term
+            X_block += -eps
+            # special handle dx_y
+            X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+            # reduction scale
+            if reduction == "mean":
+                X_block = X_block / n_non_ignore
+        else:
+            weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+            softmax_X = tl.exp(X_block - m) / d
+            # derivative of original_loss
+            dloss_ori = (1 - label_smoothing) * softmax_X
+            # specially handle dx_y
+            dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+            dloss_ori = dloss_ori * weight_y
+            # derivative of smooth_loss
+            dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+            # derivative of z-loss
+            dz_loss = 2 * lse_square_scale * lse * softmax_X
+            # reduction scale
+            if reduction == "mean":
+                dloss_ori = dloss_ori / sum_non_ignore_weight
+                dloss_smooth = dloss_smooth / sum_non_ignore_weight
+                # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+                dz_loss = dz_loss / n_non_ignore
+            # derivative of total_loss
+            X_block = dloss_ori + dloss_smooth + dz_loss
+
+        # chain rule softcapping
         # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
         if HAS_SOFTCAPPING:
             X_block = X_block * (1 - intermediate * intermediate)
@@ -182,6 +219,8 @@ def liger_cross_entropy_kernel(
     # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
     # So we can safely calculate log (softmax(X_y)) without overflow
     loss = lse - ori_X_y
+    if HAS_WEIGHT:
+        loss = weight_y * loss

     # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
     # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
@@ -192,20 +231,27 @@ def liger_cross_entropy_kernel(
     # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
    # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
     if label_smoothing > 0:
-
+        if HAS_WEIGHT:
+            smooth_loss = scaled_x_sum + eps * lse * weight_sum
+        else:
+            smooth_loss = scaled_x_sum + label_smoothing * lse
         loss = loss * (1 - label_smoothing) + smooth_loss

     # An auxiliary loss, z_loss
     # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
     z_loss = lse_square_scale * lse * lse
-    loss += z_loss
     # Normalize the loss by the number of non-ignored elements if reduction is "mean"
     if reduction == "mean":
+        if HAS_WEIGHT:
+            loss = loss / sum_non_ignore_weight
+        else:
+            loss = loss / n_non_ignore
+        # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
         z_loss = z_loss / n_non_ignore
-
+    loss += z_loss

     tl.store(loss_ptr, loss)
-    if RETURN_Z_LOSS
+    if RETURN_Z_LOSS:
         tl.store(z_loss_ptr, z_loss)


@@ -215,15 +261,10 @@ def liger_cross_entropy_kernel(
 MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manually tuning


-_bool_to_return_z_loss = {
-    True: _TRUE.value,
-    False: _FALSE.value,
-}
-
-
 def cross_entropy_forward(
     _input,
     target,
+    weight,
     ignore_index,
     lse_square_scale,
     label_smoothing,
@@ -231,15 +272,7 @@ def cross_entropy_forward(
     softcap,
     return_z_loss,
 ):
-
-        assert (
-            return_z_loss in _bool_to_return_z_loss
-        ), f"return_z_loss must be True or False. Got: {return_z_loss}"
-        return_z_loss = _bool_to_return_z_loss[return_z_loss]
-    else:
-        assert (
-            return_z_loss in _bool_to_return_z_loss
-        ), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"

     BT, V = _input.shape
     n_rows = BT
@@ -248,12 +281,22 @@ def cross_entropy_forward(

     # unreduced loss
     loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
-    if return_z_loss
-
-
-
-
-
+    z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+
+    target_mask = target != ignore_index
+    n_non_ignore = target_mask.sum().item()
+    sum_non_ignore_weight = n_non_ignore
+    weight_sum = 0.0
+    if weight is not None:
+        assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
+        assert torch.is_floating_point(
+            weight
+        ), f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
+        sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
+        weight_sum = weight.sum().item()
+        # ensure weight is contiguous
+        if weight.stride(-1) != 1:
+            weight = weight.contiguous()

     # ensure _input and target are contiguous in the last dimension
     if _input.stride(-1) != 1:
@@ -267,18 +310,22 @@ def cross_entropy_forward(
         X_stride=_input.stride(-2),
         Y_ptr=target,
         Y_stride=target.stride(-1),  # always 1
+        weight_ptr=weight,  # dummy if None
         loss_ptr=loss_1d,
         z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
         n_cols=V,
         n_non_ignore=n_non_ignore,
+        sum_non_ignore_weight=sum_non_ignore_weight,
         ignore_index=ignore_index,
+        weight_sum=weight_sum,
         lse_square_scale=lse_square_scale,
         label_smoothing=label_smoothing,
         reduction=reduction,
-        softcap=softcap
+        softcap=softcap,
         RETURN_Z_LOSS=return_z_loss,
         BLOCK_SIZE=BLOCK_SIZE,
+        HAS_WEIGHT=True if weight is not None else False,
         HAS_SOFTCAPPING=True if softcap is not None else False,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
@@ -287,10 +334,10 @@ def cross_entropy_forward(

     if reduction == "none":
         loss = loss_1d
-        z_loss = z_loss_1d if return_z_loss
+        z_loss = z_loss_1d if return_z_loss else None
     else:
         loss = torch.sum(loss_1d)
-        z_loss = torch.sum(z_loss_1d) if return_z_loss
+        z_loss = torch.sum(z_loss_1d) if return_z_loss else None

     return loss, z_loss, _input

@@ -330,6 +377,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         ctx,
         _input: torch.Tensor,
         target: torch.Tensor,
+        weight: Optional[torch.FloatTensor],
         ignore_index: int = -100,
         lse_square_scale: float = 0.0,
         label_smoothing: float = 0.0,
@@ -344,6 +392,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
            ctx : The context object.
            _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
            target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
+           weight(Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
            ignore_index (int): The index to ignore in the target.
            lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
            label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
@@ -357,6 +406,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         loss, z_loss, _input = cross_entropy_forward(
             _input,
             target,
+            weight,
             ignore_index,
             lse_square_scale,
             label_smoothing,
@@ -398,4 +448,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
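The functional core of the cross_entropy.py changes is the optional per-class weight path (`weight_ptr` / `HAS_WEIGHT`): each sample's loss is scaled by its target class's weight, and `"mean"` reduction divides by `sum_non_ignore_weight` instead of the raw count of non-ignored rows, while the auxiliary z-loss is still divided by `n_non_ignore` (see the TODO in the hunk above). Below is a minimal reference sketch of those semantics in plain PyTorch, not the Triton kernel itself; the shapes and variable names are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only.
BT, V = 8, 32
logits = torch.randn(BT, V)
target = torch.randint(0, V, (BT,))
class_weight = torch.rand(V)  # one rescaling weight per class, size V

# torch's weighted cross entropy with "mean" reduction normalizes by the summed
# weights of the selected (non-ignored) targets, which is the same bookkeeping
# the diff adds as sum_non_ignore_weight in cross_entropy_forward.
expected = F.cross_entropy(logits, target, weight=class_weight, reduction="mean")

per_row = F.cross_entropy(logits, target, weight=class_weight, reduction="none")
manual = per_row.sum() / class_weight[target].sum()
assert torch.allclose(expected, manual)
```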
--- a/liger_kernel/ops/experimental/embedding.py
+++ b/liger_kernel/ops/experimental/embedding.py
@@ -34,9 +34,7 @@ def embedding_forward_kernel(
     )

     output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :]
-    tl.store(
-        output_ptr + output_offsets, embeddings, mask=mask_m[:, None] & mask_n[None, :]
-    )
+    tl.store(output_ptr + output_offsets, embeddings, mask=mask_m[:, None] & mask_n[None, :])


 @triton.jit
--- a/liger_kernel/ops/experimental/mm_int8int2.py
+++ b/liger_kernel/ops/experimental/mm_int8int2.py
@@ -37,9 +37,7 @@ def pack_weights(intweights: torch.Tensor, bits: int = 2) -> torch.Tensor:
     else:
         packed_tensor_shape = (row_dim, *original_shape[1:])

-    packed = torch.zeros(
-        packed_tensor_shape, device=intweights.device, dtype=torch.uint8
-    )
+    packed = torch.zeros(packed_tensor_shape, device=intweights.device, dtype=torch.uint8)
     unpacked = intweights.to(torch.uint8)

     def lshift(t: torch.Tensor, bits: int):
@@ -327,17 +325,13 @@ def matmul_kernel(


 def matmul(a, b):
-    assert
-        a.shape[1] == b.shape[0] * 4
-    ), "Incompatible dimensions, the weight matrix need to be packed"
+    assert a.shape[1] == b.shape[0] * 4, "Incompatible dimensions, the weight matrix need to be packed"
     assert a.is_contiguous(), "Matrix A must be contiguous"
     M, K = a.shape
     _, N = b.shape
     # c is in int32 to avoid any overflows or underflows
     c = torch.empty((M, N), device=a.device, dtype=torch.int32)
-    grid = lambda META: (
-        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
-    )
+    grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)
     matmul_kernel[grid](
         a,
         b,
--- a/liger_kernel/ops/fused_linear_cross_entropy.py
+++ b/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -2,12 +2,10 @@ import torch
 import triton

 from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel
-from liger_kernel.ops.utils import
-
-
-
-    is_hip,
-)
+from liger_kernel.ops.utils import amp_custom_bwd
+from liger_kernel.ops.utils import amp_custom_fwd
+from liger_kernel.ops.utils import element_mul_kernel
+from liger_kernel.ops.utils import is_hip

 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -19,13 +17,16 @@ def fused_linear_cross_entropy_forward(
     _input,
     weight,
     target,
+    ce_weight=None,
     bias=None,
     ignore_index=-100,
     lse_square_scale=0.0,
     label_smoothing=0.0,
     reduction="mean",
     softcap=None,
+    return_z_loss=False,
 ):
+    assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
     device = _input.device

     # inputs have shape: BT x H
@@ -40,21 +41,32 @@ def fused_linear_cross_entropy_forward(
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

     inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
-    chunk_size = triton.next_power_of_2(
-        triton.cdiv(BT, inc_factor)
-    )  # (BT + inc_factor - 1) // inc_factor
+    chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # (BT + inc_factor - 1) // inc_factor
     num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size

-    grad_weight = (
-        torch.zeros_like(weight, device=device) if weight.requires_grad else None
-    )
+    grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
     grad_input = torch.zeros_like(_input, device=device)
     grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
     # we use fp32 for loss accumulator
     loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
-
-
-
+    z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+
+    # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
+    target_mask = target != ignore_index
+    total_n_non_ignore = target_mask.sum().item()
+    total_sum_non_ignore_ce_weight = total_n_non_ignore
+    ce_weight_sum = 0.0
+    if ce_weight is not None:
+        assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}"
+        assert torch.is_floating_point(
+            ce_weight
+        ), f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}"
+        total_sum_non_ignore_ce_weight = (
+            torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item()
+        )
+        ce_weight_sum = ce_weight.sum().item()
+        if ce_weight.stride(-1) != 1:
+            ce_weight = ce_weight.contiguous()

     for chunk_id in range(num_chunks):
         start_idx = chunk_id * chunk_size
@@ -65,13 +77,14 @@ def fused_linear_cross_entropy_forward(
         logits_chunk = _input_chunk @ weight.t()  # chunk_size x V
         if bias is not None:
             logits_chunk = logits_chunk + bias
+
         target_chunk = target[start_idx:end_idx]  # chunk_size,

         n_rows = logits_chunk.shape[0]

         # unreduced loss
         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
-
+        z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None

         # ensure _input and target are contiguous
         logits_chunk = logits_chunk.contiguous()
@@ -83,45 +96,42 @@ def fused_linear_cross_entropy_forward(
             X_stride=logits_chunk.stride(-2),
             Y_ptr=target_chunk,
             Y_stride=target_chunk.stride(-1),  # always 1
+            weight_ptr=ce_weight,
             loss_ptr=loss_1d_slice,
-            z_loss_ptr=
+            z_loss_ptr=z_loss_1d_slice,
             loss_stride=loss_1d_slice.stride(-1),  # always 1
             n_cols=V,
-            n_non_ignore=
+            n_non_ignore=total_n_non_ignore,
+            sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
+            weight_sum=ce_weight_sum,
             ignore_index=ignore_index,
             lse_square_scale=lse_square_scale,
             label_smoothing=label_smoothing,
             reduction=reduction,
-            softcap=softcap
-            RETURN_Z_LOSS=
+            softcap=softcap,
+            RETURN_Z_LOSS=return_z_loss,
+            HAS_WEIGHT=True if ce_weight is not None else False,
             HAS_SOFTCAPPING=True if softcap is not None else False,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=32 if not is_hip() else 16,
         )

-
-
-
-
-        # Thus, we need an additional scaling factor of (n_non_ignore/total_n_non_ignore) to scale the gradients.
-
-        if reduction == "mean":
-            alpha = n_non_ignore / total_n_non_ignore if total_n_non_ignore > 0 else 0.0
-        else:
-            alpha = 1.0
-
-        loss_1d[start_idx:end_idx] = loss_1d_slice * alpha
-        grad_logits_chunk = logits_chunk * alpha  # chunk_size x V
+        loss_1d[start_idx:end_idx] = loss_1d_slice
+        if return_z_loss:
+            z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
+        grad_logits_chunk = logits_chunk  # chunk_size x V

         grad_input[start_idx:end_idx] = grad_logits_chunk @ weight

         if grad_weight is not None:
             torch.addmm(
                 input=grad_weight,
-                mat1=logits_chunk.t()
+                mat1=logits_chunk.t().to(
+                    _input_chunk.dtype
+                ),  # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error.
                 mat2=_input_chunk,
                 out=grad_weight,
-                alpha=
+                alpha=1.0,
                 beta=1.0,
             )

@@ -130,18 +140,22 @@ def fused_linear_cross_entropy_forward(
                 input=grad_bias,
                 other=logits_chunk.sum(dim=0),
                 out=grad_bias,
-                alpha=
+                alpha=1.0,
             )

-
-
+    if reduction == "none":
+        loss = loss_1d
+        z_loss = z_loss_1d if return_z_loss else None

+    else:
+        loss = torch.sum(loss_1d)
+        z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+    return loss, z_loss, grad_input, grad_weight, grad_bias

-
-
-):
+
+def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
-    if torch.
+    if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
         # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
         # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
         BT, H = grad_input.shape
@@ -195,11 +209,13 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         weight,
         target,
         bias=None,
+        ce_weight=None,
         ignore_index=-100,
         lse_square_scale=0.0,
         label_smoothing=0.0,
         reduction="mean",
         softcap=None,
+        return_z_loss: bool = False,
     ):
         """
         Fusing the last linear layer with cross-entropy loss
@@ -214,21 +230,24 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         target: (B*T) where each value is in [0, V-1]
         weight: (V, H) where V is the number of classes
         bias: (V) where V is the number of classes
+        ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
         ignore_index: the index to ignore in the target
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduction: reduction to apply
         """

-        loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
-            _input,
-            weight,
-            target,
-            bias,
-
-
-
-
-
+        loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            ce_weight=ce_weight,
+            ignore_index=ignore_index,
+            lse_square_scale=lse_square_scale,
+            label_smoothing=label_smoothing,
+            reduction=reduction,
+            softcap=softcap,
+            return_z_loss=return_z_loss,
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(
@@ -236,13 +255,28 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             grad_weight.detach() if grad_weight is not None else None,
             grad_bias.detach() if bias is not None else None,
         )
-
+        ctx.return_z_loss = return_z_loss
+        return loss, z_loss

     @staticmethod
     @amp_custom_bwd
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, grad_output2):
+        if ctx.return_z_loss:
+            del grad_output2  # z_loss is only for logging
         (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
         grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
             grad_output, grad_input, grad_weight, grad_bias
         )
-        return (
+        return (
+            grad_input,
+            grad_weight,
+            None,
+            grad_bias,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
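In fused_linear_cross_entropy.py the same machinery is exposed as `ce_weight`, and the new `return_z_loss` flag makes the forward return `(loss, z_loss)` while the backward accepts and discards the extra incoming gradient (`grad_output2`). Below is a rough unfused reference of what the chunked forward computes, written in plain PyTorch; the sizes and variable names are illustrative assumptions, and the real op never materializes the full logits tensor at once.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes; the fused op iterates over chunks of the B*T dimension.
BT, H, V = 16, 64, 256
x = torch.randn(BT, H, requires_grad=True)       # _input: hidden states
lm_head = torch.randn(V, H, requires_grad=True)  # weight of the last linear layer
target = torch.randint(0, V, (BT,))
ce_weight = torch.rand(V)                        # optional per-class rescaling weight

logits = x @ lm_head.t()                         # what each chunk computes as logits_chunk
loss = F.cross_entropy(logits, target, weight=ce_weight, reduction="mean")
loss.backward()  # gradients flow to x and lm_head, mirroring grad_input / grad_weight above
```

With `return_z_loss=True`, the additional returned term corresponds to `lse_square_scale * logsumexp(logits, dim=-1) ** 2`, which the kernel adds to the loss and reports separately for logging.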