liger-kernel-nightly 0.5.10.dev20250624183504__py3-none-any.whl → 0.6.4.dev20251121224847__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of liger-kernel-nightly might be problematic.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/dpo_loss.py +54 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +25 -5
- liger_kernel/chunked_loss/grpo_loss.py +46 -9
- liger_kernel/chunked_loss/jsd_loss.py +23 -7
- liger_kernel/ops/cross_entropy.py +118 -62
- liger_kernel/ops/fused_add_rms_norm.py +412 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +113 -21
- liger_kernel/ops/geglu.py +1 -1
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/layer_norm.py +133 -79
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/poly_norm.py +386 -0
- liger_kernel/ops/rms_norm.py +2 -2
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/transformers/__init__.py +59 -0
- liger_kernel/transformers/cross_entropy.py +8 -3
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/functional.py +38 -6
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +16 -4
- liger_kernel/transformers/grpo_loss.py +56 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +28 -8
- liger_kernel/transformers/model/gemma2.py +31 -8
- liger_kernel/transformers/model/gemma3.py +100 -110
- liger_kernel/transformers/model/glm4.py +18 -5
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +26 -7
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +18 -6
- liger_kernel/transformers/model/loss_utils.py +34 -3
- liger_kernel/transformers/model/mistral.py +17 -10
- liger_kernel/transformers/model/mixtral.py +24 -9
- liger_kernel/transformers/model/mllama.py +18 -7
- liger_kernel/transformers/model/olmo2.py +18 -5
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +41 -5
- liger_kernel/transformers/model/phi3.py +24 -159
- liger_kernel/transformers/model/qwen2.py +26 -4
- liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
- liger_kernel/transformers/model/qwen2_vl.py +24 -7
- liger_kernel/transformers/model/qwen3.py +22 -6
- liger_kernel/transformers/model/qwen3_moe.py +27 -7
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +1278 -116
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/rms_norm.py +7 -0
- liger_kernel/transformers/rope.py +43 -0
- liger_kernel/transformers/swiglu.py +17 -0
- liger_kernel/transformers/tiled_mlp.py +133 -0
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/METADATA +29 -24
- liger_kernel_nightly-0.6.4.dev20251121224847.dist-info/RECORD +118 -0
- liger_kernel_nightly-0.5.10.dev20250624183504.dist-info/RECORD +0 -95
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/top_level.txt +0 -0
@@ -32,6 +32,8 @@ def liger_cross_entropy_kernel(
     loss_ptr,
     z_loss_ptr,
     loss_stride,
+    token_accuracy_ptr,
+    token_accuracy_stride,
     n_cols,
     n_non_ignore,
     sum_non_ignore_weight,
@@ -42,9 +44,11 @@ def liger_cross_entropy_kernel(
     reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
     softcap,
     RETURN_Z_LOSS: tl.constexpr,
+    RETURN_TOKEN_ACCURACY: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
+    HAS_GRADIENTS: tl.constexpr,
 ):
     """
     This kernel computes both cross entropy loss and the gradient of the input.
@@ -59,6 +63,8 @@ def liger_cross_entropy_kernel(
         loss_ptr: Pointer to tensor to store the loss.
         z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
         loss_stride (int): The stride of the loss tensor.
+        token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
+        token_accuracy_stride (int): The stride of the token accuracy tensor.
         n_cols (int): The number of columns in the input tensor.
         n_non_ignore (float): The number of non-ignored elements in the batch.
         sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
@@ -68,10 +74,12 @@ def liger_cross_entropy_kernel(
         lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
         reduction (str): The string for the reduction to apply
         softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
-        RETURN_Z_LOSS (int): The boolean value to decide whether
+        RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
+        RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
         BLOCK_SIZE (int): The block size for Triton operations.
         HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
         HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
+        HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
     """

     # https://github.com/triton-lang/triton/issues/1058
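For reference, the per-token accuracy that the new `RETURN_TOKEN_ACCURACY` path produces corresponds to the following eager-mode PyTorch computation. This is a minimal sketch for orientation only, not code from the diff; the function name is made up, and the ignore-index handling mirrors the masking shown in the hunks below.

```python
import torch

def reference_token_accuracy(logits: torch.Tensor, target: torch.Tensor, ignore_index: int = -100):
    """Eager reference for the metric the kernel stores per row: 1.0 when
    argmax(logits) == target, 0.0 otherwise, with ignored tokens forced to 0.0."""
    per_token = (logits.argmax(dim=-1) == target).float()
    per_token = torch.where(target == ignore_index, torch.zeros_like(per_token), per_token)
    n_non_ignore = (target != ignore_index).sum().clamp(min=1)
    return per_token, per_token.sum() / n_non_ignore  # (unreduced, mean over non-ignored)
```

The kernel computes the same quantity block by block during its online-softmax pass, which is how it can report accuracy without materializing the logits a second time.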
@@ -90,11 +98,17 @@ def liger_cross_entropy_kernel(
         for i in range(0, n_cols, BLOCK_SIZE):
             X_offsets = i + tl.arange(0, BLOCK_SIZE)
             tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+        # For ignored tokens, set token accuracy to 0
+        if RETURN_TOKEN_ACCURACY:
+            token_accuracy_ptr += program_id * token_accuracy_stride
+            tl.store(token_accuracy_ptr, 0.0)
         return

     loss_ptr += program_id * loss_stride
     if RETURN_Z_LOSS:
         z_loss_ptr += program_id * loss_stride
+    if RETURN_TOKEN_ACCURACY:
+        token_accuracy_ptr += program_id * token_accuracy_stride

     if HAS_WEIGHT:
         weight_y = tl.load(weight_ptr + y).cast(tl.float32)
@@ -105,6 +119,7 @@ def liger_cross_entropy_kernel(
     # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
     d = 0.0  # d is the sum. use the notation from the paper
+    argmax_idx = 0  # Track the index of the maximum value for token accuracy computation
     ori_X_y = tl.load(X_ptr + y).cast(tl.float32)  # we need to store the original value of X_y for the loss calculation
     if HAS_SOFTCAPPING:
         ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -125,6 +140,16 @@ def liger_cross_entropy_kernel(
         if HAS_SOFTCAPPING:
             X_block = softcap * tanh(X_block / softcap)
         block_max = tl.max(X_block)
+
+        # Track argmax for accuracy computation
+        if RETURN_TOKEN_ACCURACY and block_max > m:
+            # Find the index of the maximum value in this block
+            is_max_mask = X_block == block_max
+            # Mask out invalid indices with a value larger than n_cols
+            masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
+            # Get the first (smallest) index where max occurs
+            argmax_idx = tl.min(masked_offsets)
+
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
             if HAS_WEIGHT:
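The added block keeps a running argmax alongside the online-softmax max `m`: it only fires when a block's max strictly exceeds the running max, and ties inside a block resolve to the smallest column index via `tl.min` over the masked offsets. A small eager sketch of the same bookkeeping, assuming a single row and a stand-in `block_size` (illustrative names only, not from the diff):

```python
import torch

def blockwise_argmax(row: torch.Tensor, block_size: int = 4) -> int:
    """Mimic the kernel's running (max, argmax) update over fixed-size blocks of one row."""
    m = float("-inf")
    argmax_idx = 0
    for start in range(0, row.numel(), block_size):
        block = row[start : start + block_size]
        block_max = block.max().item()
        if block_max > m:  # strict comparison: ties across blocks keep the earlier index
            # first (smallest) index where the block max occurs, offset into the full row
            argmax_idx = start + int((block == block_max).nonzero()[0])
        m = max(m, block_max)
    return argmax_idx
```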
@@ -155,58 +180,58 @@ def liger_cross_entropy_kernel(
     # For 'sum' reduction, no normalization is applied:
     # dx_y = softmax(x_y) - 1
     # dx_i = softmax(x_i), for i ≠ y
-    [52 lines removed (old 158-209); their content is collapsed in this diff view]
+    if HAS_GRADIENTS:
+        for i in range(0, n_cols, BLOCK_SIZE):
+            X_offsets = i + tl.arange(0, BLOCK_SIZE)
+            X_block = tl.load(
+                X_ptr + X_offsets,
+                mask=X_offsets < n_cols,
+                other=float("-inf"),
+                # Ensure float32 precision for softmax calculation
+            ).cast(tl.float32)
+            if HAS_SOFTCAPPING:
+                intermediate = tanh(X_block / softcap)
+                X_block = softcap * intermediate
+
+            if not HAS_WEIGHT:
+                # softmax(x_i)
+                X_block = tl.exp(X_block - m) / d
+                # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+                X_block += 2 * lse_square_scale * lse * X_block
+                # smoothing term
+                X_block += -eps
+                # special handle dx_y
+                X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+                # reduction scale
+                if reduction == "mean":
+                    X_block = X_block / n_non_ignore
+            else:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                softmax_X = tl.exp(X_block - m) / d
+                # derivative of original_loss
+                dloss_ori = (1 - label_smoothing) * softmax_X
+                # specially handle dx_y
+                dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+                dloss_ori = dloss_ori * weight_y
+                # derivative of smooth_loss
+                dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+                # derivative of z-loss
+                dz_loss = 2 * lse_square_scale * lse * softmax_X
+                # reduction scale
+                if reduction == "mean":
+                    dloss_ori = dloss_ori / sum_non_ignore_weight
+                    dloss_smooth = dloss_smooth / sum_non_ignore_weight
+                    # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+                    dz_loss = dz_loss / n_non_ignore
+                # derivative of total_loss
+                X_block = dloss_ori + dloss_smooth + dz_loss
+
+            # chain rule softcapping
+            # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+            if HAS_SOFTCAPPING:
+                X_block = X_block * (1 - intermediate * intermediate)
+
+            tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

     # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
     # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
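This hunk does not change the gradient math; it wraps the pre-existing in-place gradient loop in `if HAS_GRADIENTS:` so that a forward pass on logits that do not require grad skips the work and the write-back entirely. As a rough eager reference for what the guarded block materializes in the simplest configuration (no class weights, label smoothing, softcap, or z-loss; mean reduction; a sketch, not the kernel):

```python
import torch

def reference_inplace_grad(logits: torch.Tensor, target: torch.Tensor, n_non_ignore: int) -> None:
    """Overwrite `logits` with d(mean CE loss)/d(logits), as the kernel does when HAS_GRADIENTS is set."""
    probs = torch.softmax(logits.float(), dim=-1)           # softmax(x_i)
    probs[torch.arange(target.numel()), target] -= 1.0      # dx_y = softmax(x_y) - 1
    logits.copy_((probs / n_non_ignore).to(logits.dtype))   # "mean" reduction scale
```

The flag is wired to `_input.requires_grad` at the launch site further down, so evaluation-only calls on logits produced with `requires_grad=False` (for example inside `torch.no_grad()`) get the loss and metrics without paying for the gradient pass.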
@@ -254,6 +279,10 @@ def liger_cross_entropy_kernel(
     tl.store(loss_ptr, loss)
     if RETURN_Z_LOSS:
         tl.store(z_loss_ptr, z_loss)
+    if RETURN_TOKEN_ACCURACY:
+        # Store 1.0 if prediction is correct, 0.0 otherwise
+        is_correct = 1.0 if argmax_idx == y else 0.0
+        tl.store(token_accuracy_ptr, is_correct)


 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
@@ -272,8 +301,12 @@ def cross_entropy_forward(
     reduction,
     softcap,
     return_z_loss,
+    return_token_accuracy=False,
 ):
     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+    )

     BT, V = _input.shape
     n_rows = BT
@@ -283,6 +316,9 @@ def cross_entropy_forward(
     # unreduced loss
     loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
     z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+    token_accuracy_1d = (
+        torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
+    )

     target_mask = target != ignore_index
     n_non_ignore = target_mask.sum().item()
@@ -319,6 +355,10 @@ def cross_entropy_forward(
         loss_ptr=loss_1d,
         z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
+        token_accuracy_ptr=token_accuracy_1d,
+        token_accuracy_stride=token_accuracy_1d.stride(-1)
+        if return_token_accuracy
+        else 0,  # always 1 if accuracy is enabled
         n_cols=V,
         n_non_ignore=n_non_ignore,
         sum_non_ignore_weight=sum_non_ignore_weight,
@@ -329,9 +369,11 @@ def cross_entropy_forward(
         reduction=reduction,
         softcap=softcap,
         RETURN_Z_LOSS=return_z_loss,
+        RETURN_TOKEN_ACCURACY=return_token_accuracy,
         BLOCK_SIZE=BLOCK_SIZE,
         HAS_WEIGHT=True if weight is not None else False,
        HAS_SOFTCAPPING=True if softcap is not None else False,
+        HAS_GRADIENTS=_input.requires_grad,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
         num_warps=32 if not is_hip() else 16,
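`RETURN_TOKEN_ACCURACY` and `HAS_GRADIENTS` are passed as `tl.constexpr` flags, so Triton compiles a separate specialization per flag combination and the disabled branches are removed at compile time rather than checked per element. A toy illustration of that pattern (hypothetical kernel, not from this diff):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def toy_kernel(x_ptr, out_ptr, n, DO_EXTRA: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    if DO_EXTRA:  # compile-time branch: dropped entirely when DO_EXTRA is False
        x = x * 2.0
    tl.store(out_ptr + offs, x, mask=mask)

x = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
toy_kernel[(triton.cdiv(x.numel(), 256),)](x, out, x.numel(), DO_EXTRA=True, BLOCK=256)
```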
@@ -340,11 +382,14 @@ def cross_entropy_forward(
     if reduction == "none":
         loss = loss_1d
         z_loss = z_loss_1d if return_z_loss else None
+        token_accuracy = token_accuracy_1d if return_token_accuracy else None
     else:
         loss = torch.sum(loss_1d)
         z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        # For accuracy, we compute the mean across all non-ignored tokens
+        token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None

-    return loss, z_loss, _input
+    return loss, z_loss, token_accuracy, _input


 def cross_entropy_backward(_input, grad_output):
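Note the asymmetry between the two reduction paths: with `reduction="none"` the accuracy output stays per token (one 0.0/1.0 entry per row, ignored rows already zeroed by the kernel), while any other reduction collapses it to a scalar mean over the non-ignored tokens. A short caller-side sketch for the `"none"` case (variable names are illustrative):

```python
# token_accuracy: (BT,) tensor of 0.0/1.0 values returned when reduction == "none"
# target: (BT,) labels where ignore_index marks padding
valid = target != ignore_index
accuracy = token_accuracy[valid].mean() if valid.any() else token_accuracy.new_zeros(())
```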
@@ -392,6 +437,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         reduction: str = "mean",
         softcap: Optional[float] = None,
         return_z_loss: bool = False,
+        return_token_accuracy: bool = False,
     ):
         """
         The forward pass of the Liger Cross Entropy loss.
@@ -406,12 +452,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
             reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
             softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
-            return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
+            return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False`
+            return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`

         Returns:
-            tuple: A tuple with the
+            tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested.
         """
-        loss, z_loss, _input = cross_entropy_forward(
+        input_requires_grad = _input.requires_grad
+
+        loss, z_loss, token_accuracy, _input = cross_entropy_forward(
             _input,
             target,
             weight,
@@ -421,29 +470,35 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             reduction,
             softcap,
             return_z_loss,
+            return_token_accuracy,
         )
         # TODO: investigation
         # If we don't detach the _input tensor, the memory will double
         # Not sure why but seems that there will be a time both grad and value exist but in different location
-        ctx.save_for_backward(_input.detach())
+        if input_requires_grad:
+            ctx.save_for_backward(_input.detach())
         ctx.return_z_loss = return_z_loss
+        ctx.return_token_accuracy = return_token_accuracy

-        return loss, z_loss
+        return loss, z_loss, token_accuracy

     @staticmethod
-    def backward(ctx, grad_output, grad_output2):
+    def backward(ctx, grad_output, grad_output2, grad_output3):
         """
         The backward pass of the Liger Cross Entropy loss.

         Parameters:
             ctx : The context object with saved tensors.
             grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
-            grad_output2 (
+            grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
+            grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
         Returns:
             tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
         """
         if ctx.return_z_loss:
-            del
+            del grad_output2  # z_loss is only for logging
+        if ctx.return_token_accuracy:
+            del grad_output3  # token_accuracy is only for metrics

         (_input,) = ctx.saved_tensors
         _input = cross_entropy_backward(_input, grad_output)
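Taken together, `forward` now returns a 3-tuple and `backward` accepts one more incoming gradient (and, in the final hunk below, returns one more `None`). A hedged usage sketch of the raw autograd op follows; the full positional argument order is an assumption pieced together from the parameters visible in this diff and is not confirmed by it:

```python
import torch
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

logits = torch.randn(4, 128, device="cuda", requires_grad=True)
target = torch.randint(0, 128, (4,), device="cuda")

# assumed order: (_input, target, weight, ignore_index, lse_square_scale,
#                 label_smoothing, reduction, softcap, return_z_loss, return_token_accuracy)
loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
    logits, target, None, -100, 0.0, 0.0, "mean", None, False, True
)
loss.backward()               # only `loss` carries gradient; the other outputs get None in backward
print(token_accuracy.item())  # scalar mean accuracy over non-ignored tokens
```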
@@ -457,4 +512,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )