liger-kernel-nightly 0.5.2.dev20241228022953__py3-none-any.whl → 0.5.2.dev20241229131950__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/cross_entropy.py +88 -19
- liger_kernel/ops/fused_linear_cross_entropy.py +54 -30
- liger_kernel/transformers/cross_entropy.py +3 -0
- liger_kernel/transformers/functional.py +3 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +3 -0
- {liger_kernel_nightly-0.5.2.dev20241228022953.dist-info → liger_kernel_nightly-0.5.2.dev20241229131950.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.2.dev20241228022953.dist-info → liger_kernel_nightly-0.5.2.dev20241229131950.dist-info}/RECORD +11 -11
- {liger_kernel_nightly-0.5.2.dev20241228022953.dist-info → liger_kernel_nightly-0.5.2.dev20241229131950.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241228022953.dist-info → liger_kernel_nightly-0.5.2.dev20241229131950.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241228022953.dist-info → liger_kernel_nightly-0.5.2.dev20241229131950.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20241228022953.dist-info → liger_kernel_nightly-0.5.2.dev20241229131950.dist-info}/top_level.txt +0 -0
liger_kernel/ops/cross_entropy.py

@@ -30,11 +30,14 @@ def liger_cross_entropy_kernel(
     X_stride,
     Y_ptr,
     Y_stride,
+    weight_ptr,
     loss_ptr,
     z_loss_ptr,
     loss_stride,
     n_cols,
     n_non_ignore,
+    sum_non_ignore_weight,
+    weight_sum,
     ignore_index,
     lse_square_scale: tl.constexpr,
     label_smoothing: tl.constexpr,
@@ -42,6 +45,7 @@ def liger_cross_entropy_kernel(
     softcap,
     RETURN_Z_LOSS: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
 ):
     """
@@ -53,18 +57,22 @@ def liger_cross_entropy_kernel(
         X_stride (int): The stride of the input tensor.
         Y_ptr: Pointer to target tensor.
         Y_stride (int): The stride of the target tensor.
+        weight_ptr: Pointer to weight tensor.
         loss_ptr: Pointer to tensor to store the loss.
         z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
         loss_stride (int): The stride of the loss tensor.
         n_cols (int): The number of columns in the input tensor.
-        n_non_ignore (
+        n_non_ignore (flaot): The number of non-ignored elements in the batch.
+        sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
+        weight_sum (float): The sum of weight tensor.
         ignore_index (int): The index to ignore in the target.
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
-        RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
         reduction (str): The string for the reduction to apply
         softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
+        RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
         BLOCK_SIZE (int): The block size for Triton operations.
+        HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
         HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
     """

@@ -89,6 +97,9 @@ def liger_cross_entropy_kernel(
     loss_ptr += program_id * loss_stride
     z_loss_ptr += program_id * loss_stride

+    if HAS_WEIGHT:
+        weight_y = tl.load(weight_ptr + y).cast(tl.float32)
+
     # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
     # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867

@@ -117,7 +128,11 @@ def liger_cross_entropy_kernel(
         block_max = tl.max(X_block)
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
-
+            if HAS_WEIGHT:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0))
+            else:
+                scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
         m_new = tl.maximum(m, block_max)
         d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
         m = m_new
@@ -153,18 +168,41 @@ def liger_cross_entropy_kernel(
         if HAS_SOFTCAPPING:
             intermediate = tanh(X_block / softcap)
             X_block = softcap * intermediate
-
-
-
-
-
-
-
-
-
-
-
-
+
+        if not HAS_WEIGHT:
+            # softmax(x_i)
+            X_block = tl.exp(X_block - m) / d
+            # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+            X_block += 2 * lse_square_scale * lse * X_block
+            # smoothing term
+            X_block += -eps
+            # special handle dx_y
+            X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+            # reduction scale
+            if reduction == "mean":
+                X_block = X_block / n_non_ignore
+        else:
+            weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+            softmax_X = tl.exp(X_block - m) / d
+            # derivative of original_loss
+            dloss_ori = (1 - label_smoothing) * softmax_X
+            # specially handle dx_y
+            dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+            dloss_ori = dloss_ori * weight_y
+            # derivative of smooth_loss
+            dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+            # derivative of z-loss
+            dz_loss = 2 * lse_square_scale * lse * softmax_X
+            # reduction scale
+            if reduction == "mean":
+                dloss_ori = dloss_ori / sum_non_ignore_weight
+                dloss_smooth = dloss_smooth / sum_non_ignore_weight
+                # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+                dz_loss = dz_loss / n_non_ignore
+            # derivative of total_loss
+            X_block = dloss_ori + dloss_smooth + dz_loss
+
+        # chain rule softcapping
         # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
         if HAS_SOFTCAPPING:
             X_block = X_block * (1 - intermediate * intermediate)
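The weighted gradient branch above mirrors the gradient of PyTorch's class-weighted cross entropy. As a sanity check (my own sketch, not part of the diff), the snippet below verifies the same formula, weight_y * (softmax(x) - one_hot(y)) normalized by the summed weights of the non-ignored targets, against autograd for the plain case with no label smoothing, z-loss, or softcapping:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, V = 4, 10
logits = torch.randn(B, V, requires_grad=True)
target = torch.randint(0, V, (B,))
weight = torch.rand(V)  # per-class rescaling weights

loss = F.cross_entropy(logits, target, weight=weight, reduction="mean")
loss.backward()

# dloss/dx_j = w_y * (softmax(x)_j - 1[j == y]) / sum_i w_{y_i}
p = torch.softmax(logits.detach(), dim=-1)
one_hot = F.one_hot(target, V).to(p.dtype)
manual = weight[target].unsqueeze(1) * (p - one_hot) / weight[target].sum()
print(torch.allclose(logits.grad, manual, atol=1e-6))  # True
```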
@@ -183,6 +221,8 @@ def liger_cross_entropy_kernel(
     # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
     # So we can safely calculate log (softmax(X_y)) without overflow
     loss = lse - ori_X_y
+    if HAS_WEIGHT:
+        loss = weight_y * loss

     # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
     # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
@@ -193,17 +233,24 @@ def liger_cross_entropy_kernel(
     # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
     # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
     if label_smoothing > 0:
-
+        if HAS_WEIGHT:
+            smooth_loss = scaled_x_sum + eps * lse * weight_sum
+        else:
+            smooth_loss = scaled_x_sum + label_smoothing * lse
         loss = loss * (1 - label_smoothing) + smooth_loss

     # An auxiliary loss, z_loss
     # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
     z_loss = lse_square_scale * lse * lse
-    loss += z_loss
     # Normalize the loss by the number of non-ignored elements if reduction is "mean"
     if reduction == "mean":
+        if HAS_WEIGHT:
+            loss = loss / sum_non_ignore_weight
+        else:
+            loss = loss / n_non_ignore
+        # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
         z_loss = z_loss / n_non_ignore
-
+    loss += z_loss

     tl.store(loss_ptr, loss)
     if RETURN_Z_LOSS == _TRUE:
@@ -225,6 +272,7 @@ _bool_to_return_z_loss = {
 def cross_entropy_forward(
     _input,
     target,
+    weight,
     ignore_index,
     lse_square_scale,
     label_smoothing,
@@ -250,7 +298,20 @@ def cross_entropy_forward(
     else:
         z_loss_1d = loss_1d  # dummy ptr when return_z_loss == False

-
+    target_mask = target != ignore_index
+    n_non_ignore = target_mask.sum().item()
+    sum_non_ignore_weight = n_non_ignore
+    weight_sum = 0.0
+    if weight is not None:
+        assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
+        assert torch.is_floating_point(
+            weight
+        ), f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
+        sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
+        weight_sum = weight.sum().item()
+        # ensure weight is contiguous
+        if weight.stride(-1) != 1:
+            weight = weight.contiguous()

     # ensure _input and target are contiguous in the last dimension
     if _input.stride(-1) != 1:
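For reference, the torch.gather expression above is simply a lookup of each surviving target's class weight; a small standalone check (my own example, not from the diff):

```python
import torch

weight = torch.rand(6)                  # per-class weights, V = 6
target = torch.tensor([1, 3, -100, 5])  # -100 is the ignore_index
target_mask = target != -100

gathered = torch.gather(weight, dim=0, index=target.masked_select(target_mask))
assert torch.equal(gathered, weight[target[target_mask]])  # same values as fancy indexing
sum_non_ignore_weight = gathered.sum().item()
```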
@@ -264,18 +325,22 @@ def cross_entropy_forward(
         X_stride=_input.stride(-2),
         Y_ptr=target,
         Y_stride=target.stride(-1),  # always 1
+        weight_ptr=weight if weight is not None else _input,  # dummy if None
         loss_ptr=loss_1d,
         z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
         n_cols=V,
         n_non_ignore=n_non_ignore,
+        sum_non_ignore_weight=sum_non_ignore_weight,
         ignore_index=ignore_index,
+        weight_sum=weight_sum,
         lse_square_scale=lse_square_scale,
         label_smoothing=label_smoothing,
         reduction=reduction,
         softcap=softcap if softcap is not None else 0.0,
         RETURN_Z_LOSS=return_z_loss,
         BLOCK_SIZE=BLOCK_SIZE,
+        HAS_WEIGHT=True if weight is not None else False,
         HAS_SOFTCAPPING=True if softcap is not None else False,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
@@ -327,6 +392,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         ctx,
         _input: torch.Tensor,
         target: torch.Tensor,
+        weight: Optional[torch.FloatTensor],
         ignore_index: int = -100,
         lse_square_scale: float = 0.0,
         label_smoothing: float = 0.0,
@@ -341,6 +407,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         ctx : The context object.
         _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
         target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
+        weight(Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
         ignore_index (int): The index to ignore in the target.
         lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
@@ -354,6 +421,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         loss, z_loss, _input = cross_entropy_forward(
             _input,
             target,
+            weight,
             ignore_index,
             lse_square_scale,
             label_smoothing,
@@ -395,4 +463,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
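Taken together, the forward changes make reduction="mean" a weighted mean: each non-ignored token's loss is scaled by its target's class weight and the total is divided by sum_non_ignore_weight rather than the token count. A minimal PyTorch sketch of that semantics, assuming it matches torch.nn.functional.cross_entropy (label smoothing and z-loss ignored here):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
BT, V = 8, 16
logits = torch.randn(BT, V)
target = torch.randint(0, V, (BT,))
target[2] = -100          # one ignored position
weight = torch.rand(V)

mask = target != -100
# per-token weighted losses: w_{y_i} * (-log softmax(x_i)_{y_i})
per_token = F.cross_entropy(logits[mask], target[mask], weight=weight, reduction="none")
weighted_mean = per_token.sum() / weight[target[mask]].sum()  # divide by sum_non_ignore_weight

ref = F.cross_entropy(logits, target, weight=weight, ignore_index=-100, reduction="mean")
print(torch.allclose(weighted_mean, ref))  # True
```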
liger_kernel/ops/fused_linear_cross_entropy.py

@@ -17,6 +17,7 @@ def fused_linear_cross_entropy_forward(
     _input,
     weight,
     target,
+    ce_weight=None,
     bias=None,
     ignore_index=-100,
     lse_square_scale=0.0,
@@ -47,8 +48,22 @@ def fused_linear_cross_entropy_forward(
     # we use fp32 for loss accumulator
     loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)

-    #
-
+    # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
+    target_mask = target != ignore_index
+    total_n_non_ignore = target_mask.sum().item()
+    total_sum_non_ignore_ce_weight = total_n_non_ignore
+    ce_weight_sum = 0.0
+    if ce_weight is not None:
+        assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}"
+        assert torch.is_floating_point(
+            ce_weight
+        ), f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}"
+        total_sum_non_ignore_ce_weight = (
+            torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item()
+        )
+        ce_weight_sum = ce_weight.sum().item()
+        if ce_weight.stride(-1) != 1:
+            ce_weight = ce_weight.contiguous()

     for chunk_id in range(num_chunks):
         start_idx = chunk_id * chunk_size
@@ -59,13 +74,13 @@ def fused_linear_cross_entropy_forward(
         logits_chunk = _input_chunk @ weight.t()  # chunk_size x V
         if bias is not None:
             logits_chunk = logits_chunk + bias
+
         target_chunk = target[start_idx:end_idx]  # chunk_size,

         n_rows = logits_chunk.shape[0]

         # unreduced loss
         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
-        n_non_ignore = (target_chunk != ignore_index).sum().item()

         # ensure _input and target are contiguous
         logits_chunk = logits_chunk.contiguous()
@@ -77,45 +92,40 @@ def fused_linear_cross_entropy_forward(
             X_stride=logits_chunk.stride(-2),
             Y_ptr=target_chunk,
             Y_stride=target_chunk.stride(-1),  # always 1
+            weight_ptr=ce_weight if ce_weight is not None else _input,  # dummy if None
             loss_ptr=loss_1d_slice,
             z_loss_ptr=loss_1d_slice,  # dummy ptr, not used
             loss_stride=loss_1d_slice.stride(-1),  # always 1
             n_cols=V,
-            n_non_ignore=
+            n_non_ignore=total_n_non_ignore,
+            sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
+            weight_sum=ce_weight_sum,
             ignore_index=ignore_index,
             lse_square_scale=lse_square_scale,
             label_smoothing=label_smoothing,
             reduction=reduction,
             softcap=softcap if softcap is not None else 0.0,
             RETURN_Z_LOSS=0,  # False
+            HAS_WEIGHT=True if ce_weight is not None else False,
             HAS_SOFTCAPPING=True if softcap is not None else False,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=32 if not is_hip() else 16,
         )

-
-
-        # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
-        # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens.
-        # Thus, we need an additional scaling factor of (n_non_ignore/total_n_non_ignore) to scale the gradients.
-
-        if reduction == "mean":
-            alpha = n_non_ignore / total_n_non_ignore if total_n_non_ignore > 0 else 0.0
-        else:
-            alpha = 1.0
-
-        loss_1d[start_idx:end_idx] = loss_1d_slice * alpha
-        grad_logits_chunk = logits_chunk * alpha  # chunk_size x V
+        loss_1d[start_idx:end_idx] = loss_1d_slice
+        grad_logits_chunk = logits_chunk  # chunk_size x V

         grad_input[start_idx:end_idx] = grad_logits_chunk @ weight

         if grad_weight is not None:
             torch.addmm(
                 input=grad_weight,
-                mat1=logits_chunk.t()
+                mat1=logits_chunk.t().to(
+                    _input_chunk.dtype
+                ),  # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error.
                 mat2=_input_chunk,
                 out=grad_weight,
-                alpha=
+                alpha=1.0,
                 beta=1.0,
             )

@@ -124,7 +134,7 @@ def fused_linear_cross_entropy_forward(
                 input=grad_bias,
                 other=logits_chunk.sum(dim=0),
                 out=grad_bias,
-                alpha=
+                alpha=1.0,
             )

     if reduction == "none":
@@ -190,6 +200,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         weight,
         target,
         bias=None,
+        ce_weight=None,
         ignore_index=-100,
         lse_square_scale=0.0,
         label_smoothing=0.0,
@@ -209,21 +220,23 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         target: (B*T) where each value is in [0, V-1]
         weight: (V, H) where V is the number of classes
         bias: (V) where V is the number of classes
+        ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
         ignore_index: the index to ignore in the target
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduction: reduction to apply
         """

         loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
-            _input,
-            weight,
-            target,
-            bias,
-
-
-
-
-
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            ce_weight=ce_weight,
+            ignore_index=ignore_index,
+            lse_square_scale=lse_square_scale,
+            label_smoothing=label_smoothing,
+            reduction=reduction,
+            softcap=softcap,
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(
@@ -240,4 +253,15 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
             grad_output, grad_input, grad_weight, grad_bias
         )
-        return (
+        return (
+            grad_input,
+            grad_weight,
+            None,
+            grad_bias,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
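Note that the per-chunk alpha rescaling is gone: the kernel is now launched with the global total_n_non_ignore and total_sum_non_ignore_ce_weight, so each chunk's losses and gradients already come out divided by the global denominator and simply sum to the full-batch mean. A small pure-PyTorch illustration of that bookkeeping (my own sketch, unweighted, no ignored tokens):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
BT, V, num_chunks = 12, 32, 3
logits = torch.randn(BT, V)
target = torch.randint(0, V, (BT,))

ref = F.cross_entropy(logits, target, reduction="mean")  # global mean

total_n_non_ignore = (target != -100).sum()
total = torch.tensor(0.0)
for idx in torch.chunk(torch.arange(BT), num_chunks):
    per_token = F.cross_entropy(logits[idx], target[idx], reduction="none")
    # each chunk is normalized by the *global* count, so no per-chunk alpha is needed
    total += (per_token / total_n_non_ignore).sum()

print(torch.allclose(total, ref))  # True
```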
liger_kernel/transformers/cross_entropy.py

@@ -8,6 +8,7 @@ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
 class LigerCrossEntropyLoss(torch.nn.Module):
     def __init__(
         self,
+        weight: Optional[torch.FloatTensor] = None,
         ignore_index: int = -100,
         lse_square_scale: float = 0.0,
         label_smoothing: float = 0.0,
@@ -28,6 +29,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
             "none",
         }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
         assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
+        self.weight = weight
         self.ignore_index = ignore_index
         self.lse_square_scale = lse_square_scale
         self.label_smoothing = label_smoothing
@@ -39,6 +41,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
         loss, z_loss = LigerCrossEntropyFunction.apply(
             _input,
             target,
+            self.weight,
             self.ignore_index,
             self.lse_square_scale,
             self.label_smoothing,
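Based on the constructor and apply changes above, module-level usage presumably looks like the following hypothetical sketch; it assumes a CUDA device, since the loss dispatches to a Triton kernel, and that forward returns the scalar loss by default:

```python
import torch
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

num_classes = 8
class_weights = torch.rand(num_classes, device="cuda")  # one rescaling weight per class
loss_fn = LigerCrossEntropyLoss(weight=class_weights, ignore_index=-100)

logits = torch.randn(4, num_classes, device="cuda", requires_grad=True)
target = torch.randint(0, num_classes, (4,), device="cuda")
loss = loss_fn(logits, target)  # assumes forward returns loss when return_z_loss is off
loss.backward()
```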
liger_kernel/transformers/functional.py

@@ -32,6 +32,7 @@ def liger_cross_entropy(
     loss, z_loss = LigerCrossEntropyFunction.apply(
         input,
         target,
+        weight,
         ignore_index,
         lse_square_scale,
         label_smoothing,
@@ -49,6 +50,7 @@ def liger_fused_linear_cross_entropy(
     weight,
     target,
     bias=None,
+    ce_weight=None,
     ignore_index: int = -100,
     lse_square_scale: float = 0.0,
     label_smoothing: float = 0.0,
@@ -60,6 +62,7 @@ def liger_fused_linear_cross_entropy(
         weight,
         target,
         bias,
+        ce_weight,
         ignore_index,
         lse_square_scale,
         label_smoothing,
liger_kernel/transformers/fused_linear_cross_entropy.py

@@ -8,6 +8,7 @@ from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEnt
 class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
     def __init__(
         self,
+        ce_weight: Optional[torch.FloatTensor] = None,
         ignore_index: int = -100,
         lse_square_scale: float = 0.0,
         label_smoothing: float = 0.0,
@@ -24,6 +25,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
             "none",
         }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
         assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
+        self.ce_weight = ce_weight
         self.ignore_index = ignore_index
         self.lse_square_scale = lse_square_scale
         self.label_smoothing = label_smoothing
@@ -36,6 +38,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
             lin_weight,
             target,
             bias,
+            self.ce_weight,
             self.ignore_index,
             self.lse_square_scale,
             self.label_smoothing,
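One naming pitfall worth spelling out: in this module ce_weight is the per-class rescaling vector for the cross-entropy term, while lin_weight (passed through to the function above) is the projection matrix of the fused linear layer. A hypothetical construction sketch based on the new __init__ signature:

```python
import torch
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

vocab_size = 32000
ce_weight = torch.ones(vocab_size)
ce_weight[0] = 0.1  # e.g. down-weight an over-represented class
loss_fn = LigerFusedLinearCrossEntropyLoss(ce_weight=ce_weight, reduction="mean")
```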
liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/RECORD → liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/RECORD

@@ -11,8 +11,8 @@ liger_kernel/chunked_loss/fused_linear_preference.py,sha256=25sTgvphLKAR0jyJcrsJ
 liger_kernel/chunked_loss/orpo_loss.py,sha256=jbZxx-EjPK71A6CSyNzTOAIEQgAUjfvwSViw6R_pPXQ,3510
 liger_kernel/chunked_loss/simpo_loss.py,sha256=ZvDIjT9EQrbwzH2LNZMhv84SPsOHGi_Ywk95vgA0b_o,3736
 liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-liger_kernel/ops/cross_entropy.py,sha256=
-liger_kernel/ops/fused_linear_cross_entropy.py,sha256=
+liger_kernel/ops/cross_entropy.py,sha256=4zSPzdPl-d2tB3ZOj7uRMpzI4RzZMNLUzkh6eMkH5kU,19179
+liger_kernel/ops/fused_linear_cross_entropy.py,sha256=j7cgR95rFAwtPsWZ00PfMwis5F7dtO3EVEw0rZ1GPJk,10231
 liger_kernel/ops/fused_linear_jsd.py,sha256=eKqaADj7LgWfoYqyH03tjrmhNTfJOF1Dhx_bWzBTnTU,9600
 liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
 liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
@@ -28,9 +28,9 @@ liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectfl
 liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
 liger_kernel/transformers/__init__.py,sha256=QPmYkL6hosBPpPqCUGqvIvAtD9XzLgvZqZxUyYMZeVk,2008
 liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
-liger_kernel/transformers/cross_entropy.py,sha256=
-liger_kernel/transformers/functional.py,sha256=
-liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=
+liger_kernel/transformers/cross_entropy.py,sha256=s931h9UW_tV4QMRme1HYjS_R2_C5nD6VFmZIXtjJoYo,1840
+liger_kernel/transformers/functional.py,sha256=B1wkHWLx-YNhxvXBEXB4Ch1yEwF3mjwTPCeXA5aCV_c,4490
+liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=LAN8-pjUI2Erz_MnfMer-0ZmxJ0JlKxGzdZGJY-N65g,1569
 liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
 liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
 liger_kernel/transformers/group_norm.py,sha256=URmjkQFsrbMffzcJiGpX7ckxWlpL95AiJS-80hwAWPk,2173
@@ -58,9 +58,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=MId1S_MfA3pPVQA1rkiKxp-jZDNz8VmvZzXC-Kugol4,7662
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
+liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/METADATA,sha256=iOyPsdNf1GL3Z3Ng0CS3xoOq6iiTb8eFXAMwqDT1UZM,21055
+liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/RECORD,,