liger-kernel-nightly 0.6.2.dev20251014205028__py3-none-any.whl → 0.6.2.dev20251016055812__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -414,6 +414,8 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
414
414
  Returns:
415
415
  tuple: A tuple with the computed losses with respect to loss and z loss. The elements are tensors or None.
416
416
  """
417
+ input_requires_grad = _input.requires_grad
418
+
417
419
  loss, z_loss, _input = cross_entropy_forward(
418
420
  _input,
419
421
  target,
@@ -428,7 +430,8 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
428
430
  # TODO: investigation
429
431
  # If we don't detach the _input tensor, the memory will double
430
432
  # Not sure why, but it seems there is a point in time when both the grad and the value exist, in different locations
431
- ctx.save_for_backward(_input.detach())
433
+ if input_requires_grad:
434
+ ctx.save_for_backward(_input.detach())
432
435
  ctx.return_z_loss = return_z_loss
433
436
 
434
437
  return loss, z_loss
@@ -31,6 +31,8 @@ def fused_linear_cross_entropy_forward(
31
31
  assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
32
32
  device = _input.device
33
33
 
34
+ input_requires_grad = _input.requires_grad
35
+
34
36
  # inputs have shape: BT x H
35
37
  # materialized activations will have shape: BT x V
36
38
  # the increase in memory = BT x V
@@ -49,12 +51,13 @@ def fused_linear_cross_entropy_forward(
49
51
  grad_input = torch.zeros_like(_input, device=device)
50
52
 
51
53
  # we use fp32 for loss and gradients accumulator
52
- if accum_dtype is None:
53
- grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
54
- grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
55
- else:
56
- grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
57
- grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
54
+ if input_requires_grad:
55
+ if accum_dtype is None:
56
+ grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
57
+ grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
58
+ else:
59
+ grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
60
+ grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
58
61
 
59
62
  loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
60
63
  z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
@@ -150,7 +153,7 @@ def fused_linear_cross_entropy_forward(
150
153
  RETURN_Z_LOSS=return_z_loss,
151
154
  HAS_WEIGHT=True if ce_weight is not None else False,
152
155
  HAS_SOFTCAPPING=True if softcap is not None else False,
153
- HAS_GRADIENTS=_input.requires_grad,
156
+ HAS_GRADIENTS=input_requires_grad,
154
157
  BLOCK_SIZE=BLOCK_SIZE,
155
158
  num_warps=32 if not is_hip() else 16,
156
159
  )
@@ -172,12 +175,13 @@ def fused_linear_cross_entropy_forward(
172
175
  scaling_factors_expanded = scaling_factors.unsqueeze(-1) # chunk_size x 1
173
176
  grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded
174
177
 
175
- grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
178
+ if input_requires_grad:
179
+ grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
176
180
 
177
- if grad_weight is not None and _input.requires_grad:
181
+ if grad_weight is not None and input_requires_grad:
178
182
  grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()
179
183
 
180
- if bias is not None and _input.requires_grad:
184
+ if bias is not None and input_requires_grad:
181
185
  torch.add(
182
186
  input=grad_bias,
183
187
  other=grad_logits_chunk.sum(dim=0),
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.6.2.dev20251014205028
3
+ Version: 0.6.2.dev20251016055812
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -17,10 +17,10 @@ liger_kernel/chunked_loss/kto_loss.py,sha256=llVCe6DkcpCo57seGWoMikaQVFApx764jsm
17
17
  liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmMRj7OpYU8s,4872
18
18
  liger_kernel/chunked_loss/simpo_loss.py,sha256=fy2w8KbhMrBv7b1jdIeH3bBFxY52bPQPZb3KwBvmurM,5385
19
19
  liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
- liger_kernel/ops/cross_entropy.py,sha256=OVkani9JEmCJ8IHN3UgJKzGW7zxJWDwy1EaWVcbShgQ,19517
20
+ liger_kernel/ops/cross_entropy.py,sha256=CEgAeX97ezIBRhK3dPQRKsEQiwgnBDOewtDoqKXzw_Q,19605
21
21
  liger_kernel/ops/dyt.py,sha256=gCLz4S8aul8SY9nvIGaoK67aGb7U9MJRQdo3ONqmQYs,5417
22
22
  liger_kernel/ops/fused_add_rms_norm.py,sha256=UBqmlqFCmhSAIpkNKd8rrfXatX7Z4J9bp2dX9A0lrJQ,14017
23
- liger_kernel/ops/fused_linear_cross_entropy.py,sha256=PqIPHU8EjkHRJF6cNZViDucFVOgqo7eanJxB53Npke8,14388
23
+ liger_kernel/ops/fused_linear_cross_entropy.py,sha256=rL6PyM4_9CLj7OL6qHa_ssFdJn0JEZlE12znF7T5cvM,14521
24
24
  liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHuncvSV5c,9683
25
25
  liger_kernel/ops/fused_neighborhood_attention.py,sha256=vPi5xbnh6wxyZehaqo6Tuilqo2fN5SGDiONjnNmIKqs,35556
26
26
  liger_kernel/ops/geglu.py,sha256=r0WSq9E93zzynL44Wh8femzOWK07_SseBM_pJUyxT3s,4144
@@ -101,9 +101,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
101
101
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
102
102
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
103
103
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
104
- liger_kernel_nightly-0.6.2.dev20251014205028.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
105
- liger_kernel_nightly-0.6.2.dev20251014205028.dist-info/METADATA,sha256=6VDasn5yo1wPa73CAIS4iRzr6TJ_cWpSjF_QbD5r1sM,24777
106
- liger_kernel_nightly-0.6.2.dev20251014205028.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
107
- liger_kernel_nightly-0.6.2.dev20251014205028.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
108
- liger_kernel_nightly-0.6.2.dev20251014205028.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
109
- liger_kernel_nightly-0.6.2.dev20251014205028.dist-info/RECORD,,
104
+ liger_kernel_nightly-0.6.2.dev20251016055812.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
105
+ liger_kernel_nightly-0.6.2.dev20251016055812.dist-info/METADATA,sha256=0T7yuosaQopminlzrQ4Z2ZyY7Lm_Dst67jQScbOIlHU,24777
106
+ liger_kernel_nightly-0.6.2.dev20251016055812.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
107
+ liger_kernel_nightly-0.6.2.dev20251016055812.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
108
+ liger_kernel_nightly-0.6.2.dev20251016055812.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
109
+ liger_kernel_nightly-0.6.2.dev20251016055812.dist-info/RECORD,,