liger-kernel-nightly 0.5.8.dev20250428050809__py3-none-any.whl → 0.5.8.dev20250429233059__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/cross_entropy.py +4 -1
- liger_kernel/ops/fused_linear_cross_entropy.py +4 -3
- liger_kernel/transformers/fused_linear_cross_entropy.py +1 -2
- {liger_kernel_nightly-0.5.8.dev20250428050809.dist-info → liger_kernel_nightly-0.5.8.dev20250429233059.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.8.dev20250428050809.dist-info → liger_kernel_nightly-0.5.8.dev20250429233059.dist-info}/RECORD +9 -9
- {liger_kernel_nightly-0.5.8.dev20250428050809.dist-info → liger_kernel_nightly-0.5.8.dev20250429233059.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.8.dev20250428050809.dist-info → liger_kernel_nightly-0.5.8.dev20250429233059.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.8.dev20250428050809.dist-info → liger_kernel_nightly-0.5.8.dev20250429233059.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.8.dev20250428050809.dist-info → liger_kernel_nightly-0.5.8.dev20250429233059.dist-info}/top_level.txt +0 -0
@@ -351,7 +351,10 @@ def cross_entropy_backward(_input, grad_output):
|
|
351
351
|
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
|
352
352
|
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
|
353
353
|
pass
|
354
|
-
|
354
|
+
# If reduction is 'none'
|
355
|
+
elif grad_output.ndim > 0:
|
356
|
+
_input = _input * grad_output.unsqueeze(dim=1)
|
357
|
+
# If reduction is ['mean', 'sum'], grad_output is just a scalar
|
355
358
|
# We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
|
356
359
|
# for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
|
357
360
|
else:
|
@@ -143,9 +143,10 @@ def fused_linear_cross_entropy_forward(
|
|
143
143
|
alpha=1.0,
|
144
144
|
)
|
145
145
|
|
146
|
-
if reduction
|
147
|
-
|
148
|
-
|
146
|
+
# Need extra calculations for backward if reduction=='none'. Not supporting reduction='none' now.
|
147
|
+
# if reduction == "none":
|
148
|
+
# loss = loss_1d
|
149
|
+
# z_loss = z_loss_1d if return_z_loss else None
|
149
150
|
|
150
151
|
else:
|
151
152
|
loss = torch.sum(loss_1d)
|
@@ -23,8 +23,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
|
|
23
23
|
assert reduction in {
|
24
24
|
"mean",
|
25
25
|
"sum",
|
26
|
-
|
27
|
-
}, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
|
26
|
+
}, f"reduction must be 'mean' or 'sum'. Got: {reduction}"
|
28
27
|
assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
|
29
28
|
self.ce_weight = ce_weight
|
30
29
|
self.ignore_index = ignore_index
|
@@ -16,9 +16,9 @@ liger_kernel/chunked_loss/kto_loss.py,sha256=llVCe6DkcpCo57seGWoMikaQVFApx764jsm
|
|
16
16
|
liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmMRj7OpYU8s,4872
|
17
17
|
liger_kernel/chunked_loss/simpo_loss.py,sha256=fy2w8KbhMrBv7b1jdIeH3bBFxY52bPQPZb3KwBvmurM,5385
|
18
18
|
liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
19
|
-
liger_kernel/ops/cross_entropy.py,sha256=
|
19
|
+
liger_kernel/ops/cross_entropy.py,sha256=e8THGnhOcy_0SbOLABx67HEM7-B8a8pG7nDKbCRpQKM,19123
|
20
20
|
liger_kernel/ops/dyt.py,sha256=YD1-buHz9VmIX838VKzLc-lm5CeUQ4LAskGDWBUMQHA,6187
|
21
|
-
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=
|
21
|
+
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=5fbGhN85n3zf0uIdJ7PYHWIRzTf0VTFiS0ARtOmqIP0,11020
|
22
22
|
liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHuncvSV5c,9683
|
23
23
|
liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
|
24
24
|
liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
|
@@ -38,7 +38,7 @@ liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawX
|
|
38
38
|
liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
|
39
39
|
liger_kernel/transformers/dyt.py,sha256=QMqqc14pkE0WhpRZvapfnNAun-6C0C_tHExL2ZJuCUA,648
|
40
40
|
liger_kernel/transformers/functional.py,sha256=4h9Pdx_iINBqfv2Zod_c27qOpYXDDwbdVgatQ9_XBmI,5089
|
41
|
-
liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=
|
41
|
+
liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=O8Sg5BT81nTaY9fSGoOY9dOD9ekibwwiuXhdUHaxntQ,1742
|
42
42
|
liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
|
43
43
|
liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
|
44
44
|
liger_kernel/transformers/gema3_rms.py,sha256=LTmZOXe6WEnv6ZroW-kU1TE2B36-z5v8OLmKr3XEVFo,353
|
@@ -74,9 +74,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
74
74
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
|
75
75
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
76
76
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
77
|
-
liger_kernel_nightly-0.5.8.
|
78
|
-
liger_kernel_nightly-0.5.8.
|
79
|
-
liger_kernel_nightly-0.5.8.
|
80
|
-
liger_kernel_nightly-0.5.8.
|
81
|
-
liger_kernel_nightly-0.5.8.
|
82
|
-
liger_kernel_nightly-0.5.8.
|
77
|
+
liger_kernel_nightly-0.5.8.dev20250429233059.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
78
|
+
liger_kernel_nightly-0.5.8.dev20250429233059.dist-info/METADATA,sha256=M3ZnXyCzfuYgFnBj7dbF6_i9YJ3OdWrRQDrbTkBB8rs,23297
|
79
|
+
liger_kernel_nightly-0.5.8.dev20250429233059.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
80
|
+
liger_kernel_nightly-0.5.8.dev20250429233059.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
81
|
+
liger_kernel_nightly-0.5.8.dev20250429233059.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
82
|
+
liger_kernel_nightly-0.5.8.dev20250429233059.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|