liger-kernel 0.3.1__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/env_report.py +2 -0
- liger_kernel/ops/cross_entropy.py +144 -65
- liger_kernel/ops/experimental/mm_int8int2.py +355 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +31 -11
- liger_kernel/ops/fused_linear_jsd.py +245 -0
- liger_kernel/ops/geglu.py +2 -2
- liger_kernel/ops/group_norm.py +322 -0
- liger_kernel/ops/jsd.py +176 -0
- liger_kernel/ops/kl_div.py +2 -2
- liger_kernel/ops/rms_norm.py +92 -46
- liger_kernel/ops/swiglu.py +2 -2
- liger_kernel/ops/utils.py +62 -1
- liger_kernel/transformers/__init__.py +3 -0
- liger_kernel/transformers/cross_entropy.py +44 -12
- liger_kernel/transformers/functional.py +38 -1
- liger_kernel/transformers/fused_linear_cross_entropy.py +31 -4
- liger_kernel/transformers/fused_linear_jsd.py +98 -0
- liger_kernel/transformers/group_norm.py +56 -0
- liger_kernel/transformers/jsd.py +75 -0
- liger_kernel/transformers/model/gemma.py +124 -1
- liger_kernel/transformers/model/gemma2.py +277 -0
- liger_kernel/transformers/model/llama.py +135 -4
- liger_kernel/transformers/model/mistral.py +3 -0
- liger_kernel/transformers/model/mixtral.py +153 -2
- liger_kernel/transformers/model/mllama.py +274 -0
- liger_kernel/transformers/model/phi3.py +140 -2
- liger_kernel/transformers/model/qwen2.py +123 -2
- liger_kernel/transformers/model/qwen2_vl.py +8 -1
- liger_kernel/transformers/monkey_patch.py +258 -68
- liger_kernel/transformers/rms_norm.py +11 -3
- {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/METADATA +63 -29
- liger_kernel-0.4.1.dist-info/NOTICE +58 -0
- liger_kernel-0.4.1.dist-info/RECORD +51 -0
- {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/WHEEL +1 -1
- liger_kernel-0.3.1.dist-info/NOTICE +0 -4
- liger_kernel-0.3.1.dist-info/RECORD +0 -42
- {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/LICENSE +0 -0
- {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/top_level.txt +0 -0
liger_kernel/env_report.py
CHANGED
@@ -4,11 +4,13 @@ import sys
 
 def print_env_report():
     """
+
     Prints a report of the environment. Useful for debugging and reproducibility.
     Usage:
     ```
    python -m liger_kernel.env_report
    ```
+
    """
    print("Environment Report:")
    print("-------------------")
liger_kernel/ops/cross_entropy.py
CHANGED
@@ -1,7 +1,25 @@
+import operator
+from typing import Optional
+
 import torch
 import triton
 import triton.language as tl
 
+from liger_kernel.ops.utils import compare_version, element_mul_kernel, is_hip
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import tanh
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import tanh
+else:
+    from triton.language.math import tanh
+
+_TRUE = tl.constexpr(1)
+_FALSE = tl.constexpr(0)
+
 
 @triton.jit
 def liger_cross_entropy_kernel(
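The tanh import fallback and the `compare_version` / `is_hip` helpers come from `liger_kernel/ops/utils.py`, whose own changes are listed above but not shown in this section. As a rough mental model only, here is a minimal sketch of what such helpers could look like; the implementations below are assumptions, not the package's actual code:

```python
# Hypothetical sketch only -- NOT the real liger_kernel/ops/utils.py implementation.
import torch
from importlib.metadata import version as installed_version
from packaging.version import Version


def compare_version(package: str, op, target: str) -> bool:
    # e.g. compare_version("triton", operator.ge, "3.0.0")
    try:
        current = Version(installed_version(package))
    except Exception:
        return False
    return op(current, Version(target))


def is_hip() -> bool:
    # True when the installed PyTorch build targets ROCm/HIP (AMD GPUs).
    return torch.version.hip is not None
```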
@@ -10,13 +28,18 @@ def liger_cross_entropy_kernel(
     Y_ptr,
     Y_stride,
     loss_ptr,
+    z_loss_ptr,
     loss_stride,
     n_cols,
     n_non_ignore,
     ignore_index,
+    lse_square_scale: tl.constexpr,
     label_smoothing: tl.constexpr,
     reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
+    softcap,
+    RETURN_Z_LOSS: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    HAS_SOFTCAPPING: tl.constexpr,
 ):
     """
     This kernel computes both cross entropy loss and the gradient of the input.
@@ -28,13 +51,18 @@ def liger_cross_entropy_kernel(
     Y_ptr: Pointer to target tensor.
     Y_stride (int): The stride of the target tensor.
     loss_ptr: Pointer to tensor to store the loss.
+    z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
     loss_stride (int): The stride of the loss tensor.
     n_cols (int): The number of columns in the input tensor.
     n_non_ignore (int): The number of non-ignored elements in the batch.
     ignore_index (int): The index to ignore in the target.
     label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
+    lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
+    RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
     reduction (str): The string for the reduction to apply
+    softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
     BLOCK_SIZE (int): The block size for Triton operations.
+    HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
     """
 
     # https://github.com/triton-lang/triton/issues/1058
@@ -56,6 +84,7 @@ def liger_cross_entropy_kernel(
         return
 
     loss_ptr += program_id * loss_stride
+    z_loss_ptr += program_id * loss_stride
 
     # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
     # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
@@ -66,6 +95,8 @@ def liger_cross_entropy_kernel(
     ori_X_y = tl.load(
         X_ptr + y
     )  # we need to store the original value of X_y for the loss calculation
+    if HAS_SOFTCAPPING:
+        ori_X_y = softcap * tanh(ori_X_y / softcap)
 
     # Label smoothing is a general case of normal cross entropy
     # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
@@ -77,6 +108,8 @@ def liger_cross_entropy_kernel(
         X_block = tl.load(
             X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
         )
+        if HAS_SOFTCAPPING:
+            X_block = softcap * tanh(X_block / softcap)
         block_max = tl.max(X_block)
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
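The new `HAS_SOFTCAPPING` branches implement logit softcapping, `softcap * tanh(x / softcap)`, which bounds every logit to (-softcap, +softcap) while staying close to the identity for small values. A plain-PyTorch restatement for comparison only; `softcap_logits` is a hypothetical name, not a function in the package:

```python
# Reference-only restatement of the softcapping transform (hypothetical helper name).
import torch


def softcap_logits(logits: torch.Tensor, softcap: float) -> torch.Tensor:
    # tanh is bounded in (-1, 1), so outputs stay in (-softcap, +softcap);
    # for |x| much smaller than softcap, tanh(x / softcap) ~ x / softcap, i.e. near-identity.
    return softcap * torch.tanh(logits / softcap)


print(softcap_logits(torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0]), softcap=30.0))
```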
@@ -85,32 +118,49 @@ def liger_cross_entropy_kernel(
         d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
         m = m_new
 
+    # log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X)))))
+    #                    = log (e^(max(X)) * sum(e ^ (X_i - max(X))))
+    #                    = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d
+    lse = m + tl.log(d)
+
     # 4. [Online Softmax] Second pass: compute gradients
     # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
     # dx_y = (softmax(x_y) - 1) / N
     # dx_i = softmax(x_i) / N, i != y
     # For label smoothing:
-    # dx_i = (softmax(
+    # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
     # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
     #      = dx_i - (1 - label_smoothing) / N
-    #
+    # With Z loss:
+    # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y
+    # dx_y = dx_i - (1 - label_smoothing) / N
     # For 'sum' reduction, no normalization is applied:
     # dx_y = softmax(x_y) - 1
     # dx_i = softmax(x_i), for i ≠ y
-    # For label smoothing:
-    # dx_i = (softmax(x_y) - label_smoothing / V), V = n_cols, i != y
-    # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing))
-    #      = dx_i - (1 - label_smoothing)
 
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
             X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
         )
+        if HAS_SOFTCAPPING:
+            intermediate = tanh(X_block / softcap)
+            X_block = softcap * intermediate
+        # softmax(x_i)
+        X_block = tl.exp(X_block - m) / d
+        # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+        X_block += 2 * lse_square_scale * lse * X_block
+        # smoothing term
+        X_block += -eps
+        # special handle dx_y
+        X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+        # reduction scale
         if reduction == "mean":
-            X_block =
-
-
+            X_block = X_block / (n_non_ignore)
+        # chain rule
+        # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+        if HAS_SOFTCAPPING:
+            X_block = X_block * (1 - intermediate * intermediate)
 
         tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
 
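For readers checking the second pass: per row, the stored gradient is `(1 + 2 * lse_square_scale * lse) * softmax(x_i) - eps`, with `(1 - label_smoothing)` subtracted at the target index and a `1 / n_non_ignore` scale under mean reduction. A dense PyTorch re-derivation of that formula, for intuition only (softcapping and `ignore_index` handling omitted; `ce_grad_row` is a hypothetical helper, not package code):

```python
# Dense PyTorch re-derivation of the per-row gradient stored by the second pass above.
import torch


def ce_grad_row(x, y, lse_square_scale=0.0, label_smoothing=0.0, n_non_ignore=1):
    V = x.shape[-1]
    eps = label_smoothing / V
    p = torch.softmax(x, dim=-1)                     # softmax(x_i)
    lse = torch.logsumexp(x, dim=-1)                 # m + log d
    dx = (1 + 2 * lse_square_scale * lse) * p - eps  # dx_i for i != y (z-loss term + smoothing)
    dx[y] -= 1 - label_smoothing                     # special handling of dx_y
    return dx / n_non_ignore                         # 'mean' reduction scaling
```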
@@ -122,35 +172,35 @@ def liger_cross_entropy_kernel(
 
     # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
     #      = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
+    #      = X_y - m - log d = X_y - lse
     # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
     # So we can safely calculate log (softmax(X_y)) without overflow
-    loss =
+    loss = lse - ori_X_y
 
-    #
+    # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
     # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
     #          = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
     # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
-    #          = (1 - label_smoothing) * H(q, p) + (
+    #          = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd))
     # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
     # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
     # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
     if label_smoothing > 0:
-        smooth_loss = scaled_x_sum + label_smoothing *
+        smooth_loss = scaled_x_sum + label_smoothing * lse
         loss = loss * (1 - label_smoothing) + smooth_loss
 
+    # An auxiliary loss, z_loss
+    # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
+    z_loss = lse_square_scale * lse * lse
+    loss += z_loss
     # Normalize the loss by the number of non-ignored elements if reduction is "mean"
     if reduction == "mean":
+        z_loss = z_loss / n_non_ignore
        loss = loss / n_non_ignore
 
-    # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N`
-    X_y = tl.load(X_ptr + y)
-    if reduction == "mean":
-        X_y += -(1 - label_smoothing) / (n_non_ignore)
-    else:
-        X_y += -(1 - label_smoothing)
-
     tl.store(loss_ptr, loss)
-
+    if RETURN_Z_LOSS == _TRUE:
+        tl.store(z_loss_ptr, z_loss)
 
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
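The per-row loss assembled here is `lse - x_y`, blended with the label-smoothing term and augmented with the PaLM-style z-loss `lse_square_scale * lse^2`. A scalar PyTorch restatement under mean reduction, as a sanity check only (softcapping omitted; `ce_loss_row` is a hypothetical helper):

```python
# Scalar PyTorch restatement of the per-row loss under 'mean' reduction.
import torch


def ce_loss_row(x, y, lse_square_scale=0.0, label_smoothing=0.0, n_non_ignore=1):
    V = x.shape[-1]
    eps = label_smoothing / V
    lse = torch.logsumexp(x, dim=-1)        # lse = m + log d
    loss = lse - x[y]                       # -log softmax(x_y)
    if label_smoothing > 0:
        smooth_loss = (-eps * x).sum() + label_smoothing * lse
        loss = loss * (1 - label_smoothing) + smooth_loss
    z_loss = lse_square_scale * lse * lse   # auxiliary z-loss (PaLM)
    return (loss + z_loss) / n_non_ignore, z_loss / n_non_ignore
```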
@@ -159,43 +209,32 @@ def liger_cross_entropy_kernel(
 MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manually tuning
 
 
-@triton.jit
-def element_mul_kernel(
-    X_ptr,
-    X_stride,
-    grad_output_ptr,
-    n_cols,
-    BLOCK_SIZE: tl.constexpr,
-):
-    """
-    This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr.
-    The multiplication is performed in-place on the tensor pointed by X_ptr.
-
-    Parameters:
-    X_ptr: Pointer to the input tensor.
-    X_stride (int): The stride of the input tensor.
-    grad_output_ptr: Pointer to the gradient output value.
-    n_cols (int): The number of columns in the input tensor.
-    BLOCK_SIZE (int): The block size for Triton operations.
-    """
-
-    # Get the program ID and convert it to int64 to avoid overflow
-    program_id = tl.program_id(0).to(tl.int64)
-
-    # Locate the start index
-    X_ptr += program_id * X_stride
+_bool_to_return_z_loss = {
+    True: _TRUE.value,
+    False: _FALSE.value,
+}
 
-    # Load the gradient output value
-    grad_output = tl.load(grad_output_ptr)
-
-    # Perform the element-wise multiplication
-    for i in range(0, n_cols, BLOCK_SIZE):
-        X_offsets = i + tl.arange(0, BLOCK_SIZE)
-        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
-        tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
 
+def cross_entropy_forward(
+    _input,
+    target,
+    ignore_index,
+    lse_square_scale,
+    label_smoothing,
+    reduction,
+    softcap,
+    return_z_loss,
+):
+    if not isinstance(return_z_loss, int):
+        assert (
+            return_z_loss in _bool_to_return_z_loss
+        ), f"return_z_loss must be True or False. Got: {return_z_loss}"
+        return_z_loss = _bool_to_return_z_loss[return_z_loss]
+    else:
+        assert (
+            return_z_loss in _bool_to_return_z_loss
+        ), f"return_z_loss must be True or False. Got: {return_z_loss}"
 
-def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction):
     BT, V = _input.shape
     n_rows = BT
 
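One reading of the `_bool_to_return_z_loss` mapping (my inference, not stated in the diff): `_TRUE` and `_FALSE` wrap 1 and 0 in `tl.constexpr` so `RETURN_Z_LOSS` specializes the kernel at compile time, and the dict converts a Python bool into that integer form before launch:

```python
# Illustration of the boolean-to-constexpr conversion (mirrors the definitions in this diff).
import triton.language as tl

_TRUE = tl.constexpr(1)
_FALSE = tl.constexpr(0)
_bool_to_return_z_loss = {True: _TRUE.value, False: _FALSE.value}

print(_bool_to_return_z_loss[True])   # 1 -> kernel compiled with the z-loss store
print(_bool_to_return_z_loss[False])  # 0 -> z_loss_ptr is never written
```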
@@ -203,6 +242,10 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reducti
 
     # unreduced loss
     loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
+    if return_z_loss == _TRUE.value:
+        z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
+    else:
+        z_loss_1d = loss_1d  # dummy ptr when return_z_loss == False
 
     n_non_ignore = (target != ignore_index).sum().item()
 
@@ -219,20 +262,30 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reducti
         Y_ptr=target,
         Y_stride=target.stride(-1),  # always 1
         loss_ptr=loss_1d,
+        z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
         n_cols=V,
         n_non_ignore=n_non_ignore,
         ignore_index=ignore_index,
+        lse_square_scale=lse_square_scale,
         label_smoothing=label_smoothing,
         reduction=reduction,
+        softcap=softcap if softcap is not None else 0.0,
+        RETURN_Z_LOSS=return_z_loss,
         BLOCK_SIZE=BLOCK_SIZE,
+        HAS_SOFTCAPPING=True if softcap is not None else False,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
-        num_warps=32,
+        num_warps=32 if not is_hip() else 16,
     )
 
     loss = torch.sum(loss_1d)
-
+    if return_z_loss == _TRUE.value:
+        z_loss = torch.sum(z_loss_1d)
+    else:
+        z_loss = None
+
+    return loss, z_loss, _input
 
 
 def cross_entropy_backward(_input, grad_output):
@@ -253,7 +306,7 @@ def cross_entropy_backward(_input, grad_output):
         grad_output,
         V,
         BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=32,
+        num_warps=32 if not is_hip() else 16,
     )
 
     return _input
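A possible rationale for launching with `num_warps=16` on ROCm (my inference; the diff gives no reason): AMD wavefronts are 64 lanes wide versus 32-lane CUDA warps, so halving `num_warps` keeps roughly the same 1024 threads per block on both backends.

```python
# Back-of-envelope check for the assumption above (not taken from the package).
CUDA_WARP_SIZE, HIP_WAVEFRONT_SIZE = 32, 64
print(32 * CUDA_WARP_SIZE)      # 1024 threads per block on NVIDIA
print(16 * HIP_WAVEFRONT_SIZE)  # 1024 threads per block on AMD CDNA GPUs
```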
@@ -267,7 +320,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
 
     @staticmethod
     def forward(
-        ctx,
+        ctx,
+        _input: torch.Tensor,
+        target: torch.Tensor,
+        ignore_index: int = -100,
+        lse_square_scale: float = 0.0,
+        label_smoothing: float = 0.0,
+        reduction: str = "mean",
+        softcap: Optional[float] = None,
+        return_z_loss: bool = False,
     ):
         """
         The forward pass of the Liger Cross Entropy loss.
@@ -277,33 +338,48 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
         target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
         ignore_index (int): The index to ignore in the target.
+        lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
+        softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
+        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
 
         Returns:
-
+        tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None.
         """
-        loss, _input = cross_entropy_forward(
-            _input,
+        loss, z_loss, _input = cross_entropy_forward(
+            _input,
+            target,
+            ignore_index,
+            lse_square_scale,
+            label_smoothing,
+            reduction,
+            softcap,
+            return_z_loss,
         )
         # TODO: investigation
         # If we don't detach the _input tensor, the memory will double
         # Not sure why but seems that there will be a time both grad and value exist but in different location
         ctx.save_for_backward(_input.detach())
-
+        ctx.return_z_loss = return_z_loss
+
+        return loss, z_loss
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, grad_ouput2):
         """
         The backward pass of the Liger Cross Entropy loss.
 
         Parameters:
         ctx : The context object with saved tensors.
         grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
-
+        grad_output2 (tenosr): No use.
         Returns:
         tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
         """
+        if ctx.return_z_loss:
+            del grad_ouput2  # z_loss is only for logging
+
         (_input,) = ctx.saved_tensors
         _input = cross_entropy_backward(_input, grad_output)
         return (
@@ -312,4 +388,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,
+            None,
+            None,
         )
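Putting the new `forward` signature to work: a usage sketch based only on the signature shown above (illustrative, not from the package docs). Arguments are passed positionally through `torch.autograd.Function.apply`, `z_loss` comes back as a second output, and `backward` returns one gradient (or `None`) per `forward` input, which is why three extra `None`s were appended above.

```python
# Usage sketch based on the forward signature above; illustrative only, requires a GPU.
import torch
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

BT, V = 8, 128
logits = torch.randn(BT, V, device="cuda", requires_grad=True)
target = torch.randint(0, V, (BT,), device="cuda")

loss, z_loss = LigerCrossEntropyFunction.apply(
    logits,
    target,
    -100,    # ignore_index
    1e-4,    # lse_square_scale
    0.0,     # label_smoothing
    "mean",  # reduction
    30.0,    # softcap
    True,    # return_z_loss
)
loss.backward()  # z_loss is for logging; gradients flow through `loss`
```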