liger-kernel-nightly 0.6.2.dev20251014053719__py3-none-any.whl → 0.6.2.dev20251016055812__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/cross_entropy.py +4 -1
- liger_kernel/ops/fused_linear_cross_entropy.py +14 -10
- {liger_kernel_nightly-0.6.2.dev20251014053719.dist-info → liger_kernel_nightly-0.6.2.dev20251016055812.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.6.2.dev20251014053719.dist-info → liger_kernel_nightly-0.6.2.dev20251016055812.dist-info}/RECORD +8 -8
- {liger_kernel_nightly-0.6.2.dev20251014053719.dist-info → liger_kernel_nightly-0.6.2.dev20251016055812.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.2.dev20251014053719.dist-info → liger_kernel_nightly-0.6.2.dev20251016055812.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.2.dev20251014053719.dist-info → liger_kernel_nightly-0.6.2.dev20251016055812.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.2.dev20251014053719.dist-info → liger_kernel_nightly-0.6.2.dev20251016055812.dist-info}/top_level.txt +0 -0
liger_kernel/ops/cross_entropy.py

```diff
@@ -414,6 +414,8 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         Returns:
             tuple: A tuple with the computed losses with respect to loss and z loss. The elements are tensors or None.
         """
+        input_requires_grad = _input.requires_grad
+
         loss, z_loss, _input = cross_entropy_forward(
             _input,
             target,
@@ -428,7 +430,8 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         # TODO: investigation
         # If we don't detach the _input tensor, the memory will double
         # Not sure why but seems that there will be a time both grad and value exist but in different location
-        ctx.save_for_backward(_input.detach())
+        if input_requires_grad:
+            ctx.save_for_backward(_input.detach())
         ctx.return_z_loss = return_z_loss

         return loss, z_loss
```
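The pattern introduced here — cache `_input.requires_grad` up front and only call `ctx.save_for_backward` when a backward pass is actually possible — means inference-only calls no longer retain a detached copy of the input. A minimal standalone sketch of the same pattern (a toy function, not the Liger kernel itself):

```python
import torch


class GuardedSave(torch.autograd.Function):
    """Toy autograd.Function illustrating the guard pattern from the diff above."""

    @staticmethod
    def forward(ctx, x):
        # Cache the flag up front, mirroring `input_requires_grad` in the diff.
        ctx.input_requires_grad = x.requires_grad
        if ctx.input_requires_grad:
            # Only stash the (detached) tensor when backward can actually run.
            ctx.save_for_backward(x.detach())
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_out):
        # Gradient of x * 2 is 2 * grad_out; the saved tensor is unused here,
        # but saved_tensors is only populated on the training path.
        return grad_out * 2.0


# Training path: the input is saved and backward works as usual.
x = torch.randn(4, requires_grad=True)
GuardedSave.apply(x).sum().backward()

# Inference path: forward runs without retaining any saved activation.
with torch.no_grad():
    _ = GuardedSave.apply(torch.randn(4))
```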
liger_kernel/ops/fused_linear_cross_entropy.py

```diff
@@ -31,6 +31,8 @@ def fused_linear_cross_entropy_forward(
     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
     device = _input.device

+    input_requires_grad = _input.requires_grad
+
     # inputs have shape: BT x H
     # materialized activations will have shape: BT x V
     # the increase in memory = BT x V
@@ -49,12 +51,13 @@ def fused_linear_cross_entropy_forward(
     grad_input = torch.zeros_like(_input, device=device)

     # we use fp32 for loss and gradients accumulator
-    if accum_dtype is None:
-        grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
-        grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
-    else:
-        grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
-        grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
+    if input_requires_grad:
+        if accum_dtype is None:
+            grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
+            grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
+        else:
+            grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
+            grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None

     loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
     z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
```
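For reference, the `accum_dtype` branch above selects the dtype of the gradient accumulators: with low-precision weights (e.g. bf16), accumulating many chunked additions in fp32 avoids compounding rounding error. A small standalone illustration of that choice, with made-up shapes:

```python
import torch

weight = torch.randn(8, 4, dtype=torch.bfloat16)
accum_dtype = torch.float32  # None would keep the accumulator in weight.dtype

# Same shape and device as `weight`, but accumulating in fp32.
grad_weight = torch.zeros_like(weight, dtype=accum_dtype)

for _ in range(4):  # stand-in for the per-chunk `grad_weight +=` updates
    grad_weight += torch.randn(8, 4).float()
```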
```diff
@@ -150,7 +153,7 @@ def fused_linear_cross_entropy_forward(
        RETURN_Z_LOSS=return_z_loss,
        HAS_WEIGHT=True if ce_weight is not None else False,
        HAS_SOFTCAPPING=True if softcap is not None else False,
-       HAS_GRADIENTS=_input.requires_grad,
+       HAS_GRADIENTS=input_requires_grad,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=32 if not is_hip() else 16,
    )
@@ -172,12 +175,13 @@ def fused_linear_cross_entropy_forward(
         scaling_factors_expanded = scaling_factors.unsqueeze(-1)  # chunk_size x 1
         grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded

-        grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
+        if input_requires_grad:
+            grad_input[start_idx:end_idx] = grad_logits_chunk @ weight

-        if grad_weight is not None and _input.requires_grad:
+        if grad_weight is not None and input_requires_grad:
             grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()

-        if bias is not None and _input.requires_grad:
+        if bias is not None and input_requires_grad:
             torch.add(
                 input=grad_bias,
                 other=grad_logits_chunk.sum(dim=0),
```
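Taken together, the loop changes above mean that when `input_requires_grad` is False, the per-chunk `grad_logits_chunk @ weight` matmul and the weight/bias accumulations are skipped entirely. A rough standalone sketch of that control flow (illustrative names and shapes, not the real kernel driver):

```python
import torch

BT, H, V, chunk_size = 16, 8, 32, 4
_input = torch.randn(BT, H)   # eval-style input: requires_grad=False
weight = torch.randn(V, H)
input_requires_grad = _input.requires_grad

grad_input = torch.zeros_like(_input)
for start_idx in range(0, BT, chunk_size):
    end_idx = start_idx + chunk_size
    # Stand-in for the gradient chunk the Triton kernel writes in place.
    grad_logits_chunk = torch.randn(chunk_size, V)
    if input_requires_grad:
        grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
# With requires_grad=False, grad_input stays untouched (all zeros).
```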
liger_kernel_nightly-{0.6.2.dev20251014053719 → 0.6.2.dev20251016055812}.dist-info/RECORD

```diff
@@ -17,10 +17,10 @@ liger_kernel/chunked_loss/kto_loss.py,sha256=llVCe6DkcpCo57seGWoMikaQVFApx764jsm
 liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmMRj7OpYU8s,4872
 liger_kernel/chunked_loss/simpo_loss.py,sha256=fy2w8KbhMrBv7b1jdIeH3bBFxY52bPQPZb3KwBvmurM,5385
 liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-liger_kernel/ops/cross_entropy.py,sha256=
+liger_kernel/ops/cross_entropy.py,sha256=CEgAeX97ezIBRhK3dPQRKsEQiwgnBDOewtDoqKXzw_Q,19605
 liger_kernel/ops/dyt.py,sha256=gCLz4S8aul8SY9nvIGaoK67aGb7U9MJRQdo3ONqmQYs,5417
 liger_kernel/ops/fused_add_rms_norm.py,sha256=UBqmlqFCmhSAIpkNKd8rrfXatX7Z4J9bp2dX9A0lrJQ,14017
-liger_kernel/ops/fused_linear_cross_entropy.py,sha256=
+liger_kernel/ops/fused_linear_cross_entropy.py,sha256=rL6PyM4_9CLj7OL6qHa_ssFdJn0JEZlE12znF7T5cvM,14521
 liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHuncvSV5c,9683
 liger_kernel/ops/fused_neighborhood_attention.py,sha256=vPi5xbnh6wxyZehaqo6Tuilqo2fN5SGDiONjnNmIKqs,35556
 liger_kernel/ops/geglu.py,sha256=r0WSq9E93zzynL44Wh8femzOWK07_SseBM_pJUyxT3s,4144
@@ -101,9 +101,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.6.2.dev20251014053719.dist-info/LICENSE,sha256=
-liger_kernel_nightly-0.6.2.dev20251014053719.dist-info/METADATA,sha256=
-liger_kernel_nightly-0.6.2.dev20251014053719.dist-info/NOTICE,sha256=
-liger_kernel_nightly-0.6.2.dev20251014053719.dist-info/WHEEL,sha256=
-liger_kernel_nightly-0.6.2.dev20251014053719.dist-info/top_level.txt,sha256=
-liger_kernel_nightly-0.6.2.dev20251014053719.dist-info/RECORD,,
+liger_kernel_nightly-0.6.2.dev20251016055812.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.6.2.dev20251016055812.dist-info/METADATA,sha256=0T7yuosaQopminlzrQ4Z2ZyY7Lm_Dst67jQScbOIlHU,24777
+liger_kernel_nightly-0.6.2.dev20251016055812.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.6.2.dev20251016055812.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.6.2.dev20251016055812.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.6.2.dev20251016055812.dist-info/RECORD,,
```