liger-kernel-nightly 0.5.8.dev20250428050809__py3-none-any.whl → 0.5.8.dev20250429233059__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -351,7 +351,10 @@ def cross_entropy_backward(_input, grad_output):
351
351
  # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
352
352
  if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
353
353
  pass
354
-
354
+ # If reduction is 'none'
355
+ elif grad_output.ndim > 0:
356
+ _input = _input * grad_output.unsqueeze(dim=1)
357
+ # If reduction is ['mean', 'sum'], grad_output is just a scalar
355
358
  # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
356
359
  # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
357
360
  else:
@@ -143,9 +143,10 @@ def fused_linear_cross_entropy_forward(
143
143
  alpha=1.0,
144
144
  )
145
145
 
146
- if reduction == "none":
147
- loss = loss_1d
148
- z_loss = z_loss_1d if return_z_loss else None
146
+ # Need extra calculations for backward if reduction=='none'. Not supporting reduction='none' now.
147
+ # if reduction == "none":
148
+ # loss = loss_1d
149
+ # z_loss = z_loss_1d if return_z_loss else None
149
150
 
150
151
  else:
151
152
  loss = torch.sum(loss_1d)
@@ -23,8 +23,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
23
23
  assert reduction in {
24
24
  "mean",
25
25
  "sum",
26
- "none",
27
- }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
26
+ }, f"reduction must be 'mean' or 'sum'. Got: {reduction}"
28
27
  assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
29
28
  self.ce_weight = ce_weight
30
29
  self.ignore_index = ignore_index
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.8.dev20250428050809
3
+ Version: 0.5.8.dev20250429233059
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -16,9 +16,9 @@ liger_kernel/chunked_loss/kto_loss.py,sha256=llVCe6DkcpCo57seGWoMikaQVFApx764jsm
16
16
  liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmMRj7OpYU8s,4872
17
17
  liger_kernel/chunked_loss/simpo_loss.py,sha256=fy2w8KbhMrBv7b1jdIeH3bBFxY52bPQPZb3KwBvmurM,5385
18
18
  liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
- liger_kernel/ops/cross_entropy.py,sha256=T5oSsqOS1y-Iea5o9v_BSU-_mIEXqWAT1oX_m59NcA4,18941
19
+ liger_kernel/ops/cross_entropy.py,sha256=e8THGnhOcy_0SbOLABx67HEM7-B8a8pG7nDKbCRpQKM,19123
20
20
  liger_kernel/ops/dyt.py,sha256=YD1-buHz9VmIX838VKzLc-lm5CeUQ4LAskGDWBUMQHA,6187
21
- liger_kernel/ops/fused_linear_cross_entropy.py,sha256=1Y3Uk_TCSjqKgoG2eot1ptnWXJXXQESqGvOmqAW1gsM,10912
21
+ liger_kernel/ops/fused_linear_cross_entropy.py,sha256=5fbGhN85n3zf0uIdJ7PYHWIRzTf0VTFiS0ARtOmqIP0,11020
22
22
  liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHuncvSV5c,9683
23
23
  liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
24
24
  liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
@@ -38,7 +38,7 @@ liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawX
38
38
  liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
39
39
  liger_kernel/transformers/dyt.py,sha256=QMqqc14pkE0WhpRZvapfnNAun-6C0C_tHExL2ZJuCUA,648
40
40
  liger_kernel/transformers/functional.py,sha256=4h9Pdx_iINBqfv2Zod_c27qOpYXDDwbdVgatQ9_XBmI,5089
41
- liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=09Rt7FZzLH42VOcIbQ4dlQd0o3Rlb4vk6fqiOQ7WTD8,1778
41
+ liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=O8Sg5BT81nTaY9fSGoOY9dOD9ekibwwiuXhdUHaxntQ,1742
42
42
  liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
43
43
  liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
44
44
  liger_kernel/transformers/gema3_rms.py,sha256=LTmZOXe6WEnv6ZroW-kU1TE2B36-z5v8OLmKr3XEVFo,353
@@ -74,9 +74,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
74
74
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
75
75
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
76
76
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
77
- liger_kernel_nightly-0.5.8.dev20250428050809.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
78
- liger_kernel_nightly-0.5.8.dev20250428050809.dist-info/METADATA,sha256=JjME6LPHPZBmQ_lQlX3mWasYbAkKd0r6ZgEfsyeIGx8,23297
79
- liger_kernel_nightly-0.5.8.dev20250428050809.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
80
- liger_kernel_nightly-0.5.8.dev20250428050809.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
81
- liger_kernel_nightly-0.5.8.dev20250428050809.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
82
- liger_kernel_nightly-0.5.8.dev20250428050809.dist-info/RECORD,,
77
+ liger_kernel_nightly-0.5.8.dev20250429233059.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
78
+ liger_kernel_nightly-0.5.8.dev20250429233059.dist-info/METADATA,sha256=M3ZnXyCzfuYgFnBj7dbF6_i9YJ3OdWrRQDrbTkBB8rs,23297
79
+ liger_kernel_nightly-0.5.8.dev20250429233059.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
80
+ liger_kernel_nightly-0.5.8.dev20250429233059.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
81
+ liger_kernel_nightly-0.5.8.dev20250429233059.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
82
+ liger_kernel_nightly-0.5.8.dev20250429233059.dist-info/RECORD,,