liger-kernel-nightly 0.6.1.dev20250819173444__py3-none-any.whl → 0.6.2.dev20250822031319__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -26,6 +26,7 @@ def fused_linear_cross_entropy_forward(
26
26
  softcap=None,
27
27
  return_z_loss=False,
28
28
  accum_dtype=None,
29
+ use_token_scaling=False,
29
30
  ):
30
31
  assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
31
32
  device = _input.device
@@ -89,6 +90,23 @@ def fused_linear_cross_entropy_forward(
89
90
 
90
91
  n_rows = logits_chunk.shape[0]
91
92
 
93
+ # Compute predicted probabilities for token scaling if needed
94
+ if use_token_scaling:
95
+ # Compute softmax probabilities for scaling
96
+ # We need to compute this before the cross entropy kernel modifies logits_chunk
97
+ logits_for_softmax = logits_chunk.detach().clone() # Detach to avoid gradient flow
98
+ if softcap is not None:
99
+ logits_for_softmax = softcap * torch.tanh(logits_for_softmax / softcap)
100
+
101
+ # Compute softmax to get predicted probabilities
102
+ probs = torch.softmax(logits_for_softmax, dim=-1)
103
+
104
+ # Get the predicted probability for each target token
105
+ pred_probs = torch.gather(probs, -1, target_chunk.unsqueeze(-1)).squeeze(-1)
106
+
107
+ # Store the scaling factors
108
+ scaling_factors = pred_probs.detach() # Detach to ensure no gradient flow
109
+
92
110
  # unreduced loss
93
111
  loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size,
94
112
  z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
@@ -123,11 +141,23 @@ def fused_linear_cross_entropy_forward(
123
141
  num_warps=32 if not is_hip() else 16,
124
142
  )
125
143
 
144
+ # Apply token scaling if requested
145
+ if use_token_scaling:
146
+ loss_1d_slice = loss_1d_slice * scaling_factors
147
+ if return_z_loss:
148
+ z_loss_1d_slice = z_loss_1d_slice * scaling_factors
149
+
126
150
  loss_1d[start_idx:end_idx] = loss_1d_slice
127
151
  if return_z_loss:
128
152
  z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
129
153
  grad_logits_chunk = logits_chunk # chunk_size x V
130
154
 
155
+ # Apply token scaling to gradients if requested
156
+ if use_token_scaling:
157
+ # Expand scaling factors to match gradient dimensions
158
+ scaling_factors_expanded = scaling_factors.unsqueeze(-1) # chunk_size x 1
159
+ grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded
160
+
131
161
  grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
132
162
 
133
163
  if grad_weight is not None:
@@ -136,7 +166,7 @@ def fused_linear_cross_entropy_forward(
136
166
  if bias is not None:
137
167
  torch.add(
138
168
  input=grad_bias,
139
- other=logits_chunk.sum(dim=0),
169
+ other=grad_logits_chunk.sum(dim=0),
140
170
  out=grad_bias,
141
171
  alpha=1.0,
142
172
  )
@@ -146,6 +176,10 @@ def fused_linear_cross_entropy_forward(
146
176
  # loss = loss_1d
147
177
  # z_loss = z_loss_1d if return_z_loss else None
148
178
 
179
+ if reduction == "none":
180
+ # Return per-token losses
181
+ loss = loss_1d
182
+ z_loss = z_loss_1d if return_z_loss else None
149
183
  else:
150
184
  loss = torch.sum(loss_1d)
151
185
  z_loss = torch.sum(z_loss_1d) if return_z_loss else None
@@ -221,6 +255,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
221
255
  softcap=None,
222
256
  return_z_loss: bool = False,
223
257
  accum_dtype=None,
258
+ use_token_scaling: bool = False,
224
259
  ):
225
260
  """
226
261
  Fusing the last linear layer with cross-entropy loss
@@ -241,6 +276,9 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
241
276
  reduction: reduction to apply
242
277
  accum_dtype (torch.dtype): the dtype of intermediate result buffers for weight and bias gradient accumulations.
243
278
  Recommended to set `accum_dtype` to higher precision, e.g. `torch.float32`, if the training is unstable with original dtype. Default: `None`, performing accumulations in original dtype
279
+ use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
280
+ When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
281
+ Default: False.
244
282
  """
245
283
 
246
284
  loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
@@ -256,6 +294,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
256
294
  softcap=softcap,
257
295
  return_z_loss=return_z_loss,
258
296
  accum_dtype=accum_dtype,
297
+ use_token_scaling=use_token_scaling,
259
298
  )
260
299
  # downcast to dtype and store for backward
261
300
  ctx.save_for_backward(
@@ -288,4 +327,5 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
288
327
  None,
289
328
  None,
290
329
  None,
330
+ None, # use_token_scaling
291
331
  )
@@ -65,6 +65,7 @@ def liger_fused_linear_cross_entropy(
65
65
  softcap: Optional[float] = None,
66
66
  return_z_loss: bool = False,
67
67
  accum_dtype=None,
68
+ use_token_scaling: bool = False,
68
69
  ):
69
70
  loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
70
71
  input,
@@ -79,6 +80,7 @@ def liger_fused_linear_cross_entropy(
79
80
  softcap,
80
81
  return_z_loss,
81
82
  accum_dtype,
83
+ use_token_scaling,
82
84
  )
83
85
  if not return_z_loss:
84
86
  return loss
@@ -16,6 +16,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
16
16
  softcap: Optional[float] = None,
17
17
  return_z_loss: bool = False,
18
18
  accum_dtype: Optional[torch.dtype] = None,
19
+ use_token_scaling: bool = False,
19
20
  ):
20
21
  super().__init__()
21
22
  assert (label_smoothing >= 0) and (label_smoothing <= 1), (
@@ -34,6 +35,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
34
35
  self.softcap = softcap
35
36
  self.return_z_loss = return_z_loss
36
37
  self.accum_dtype = accum_dtype
38
+ self.use_token_scaling = use_token_scaling
37
39
 
38
40
  def forward(self, lin_weight, _input, target, bias=None):
39
41
  loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
@@ -49,6 +51,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
49
51
  self.softcap,
50
52
  self.return_z_loss,
51
53
  self.accum_dtype,
54
+ self.use_token_scaling,
52
55
  )
53
56
  if not self.return_z_loss:
54
57
  return loss
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.6.1.dev20250819173444
3
+ Version: 0.6.2.dev20250822031319
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -20,7 +20,7 @@ liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
20
20
  liger_kernel/ops/cross_entropy.py,sha256=e8THGnhOcy_0SbOLABx67HEM7-B8a8pG7nDKbCRpQKM,19123
21
21
  liger_kernel/ops/dyt.py,sha256=gCLz4S8aul8SY9nvIGaoK67aGb7U9MJRQdo3ONqmQYs,5417
22
22
  liger_kernel/ops/fused_add_rms_norm.py,sha256=UBqmlqFCmhSAIpkNKd8rrfXatX7Z4J9bp2dX9A0lrJQ,14017
23
- liger_kernel/ops/fused_linear_cross_entropy.py,sha256=YFPXUOIZpM_4r7AlfjkwOgDhAE_0H2mFjdKtx8cv-T4,11594
23
+ liger_kernel/ops/fused_linear_cross_entropy.py,sha256=AIlKMOnM3J7ZeAgPP1uvA3T4OIeRkz6TTr_Lg9XgZGY,13581
24
24
  liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHuncvSV5c,9683
25
25
  liger_kernel/ops/fused_neighborhood_attention.py,sha256=vPi5xbnh6wxyZehaqo6Tuilqo2fN5SGDiONjnNmIKqs,35556
26
26
  liger_kernel/ops/geglu.py,sha256=r0WSq9E93zzynL44Wh8femzOWK07_SseBM_pJUyxT3s,4144
@@ -46,9 +46,9 @@ liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawX
46
46
  liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
47
47
  liger_kernel/transformers/dyt.py,sha256=i-4GPaMrl-jab9TVI5qN0-H9qycn_mCbV82ozU4nbmU,723
48
48
  liger_kernel/transformers/fsdp.py,sha256=CUiyjTmjkjY7pLXQv8ly9rnzgXw6529csd9pvtJNMYc,3096
49
- liger_kernel/transformers/functional.py,sha256=XkYk_zb8xsRMtZtouYmlX_Tyyr-QA3WigSPF36DECYk,7777
49
+ liger_kernel/transformers/functional.py,sha256=-vpz95wbv5wLpInjSG06KNHETsEgKnRIiV-lMYHVs68,7841
50
50
  liger_kernel/transformers/fused_add_rms_norm.py,sha256=7_Bzg-x6lLe6W1qG2DtjDALhEpNZlC6N5GppEs9cTYY,1199
51
- liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=_5AaQT2mcUEO2T7JGJYQafz6A1Efn9d3-Z3xFO_Xe0o,1862
51
+ liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=ZMxkiJzGz1KtqgAdsqPODq3bugHBx_80kPYcd5z-xmM,1990
52
52
  liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
53
53
  liger_kernel/transformers/fused_neighborhood_attention.py,sha256=TxYDUAt9B6WSP14aJP66C_2Mbds2sSIPGnamhUSTrC8,7957
54
54
  liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
@@ -96,9 +96,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
96
96
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
97
97
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
98
98
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
99
- liger_kernel_nightly-0.6.1.dev20250819173444.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
100
- liger_kernel_nightly-0.6.1.dev20250819173444.dist-info/METADATA,sha256=OaVW-70Zf6I4qZbU4W9HcUlXza8L-zhHOmyViKLUftQ,24504
101
- liger_kernel_nightly-0.6.1.dev20250819173444.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
102
- liger_kernel_nightly-0.6.1.dev20250819173444.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
103
- liger_kernel_nightly-0.6.1.dev20250819173444.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
104
- liger_kernel_nightly-0.6.1.dev20250819173444.dist-info/RECORD,,
99
+ liger_kernel_nightly-0.6.2.dev20250822031319.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
100
+ liger_kernel_nightly-0.6.2.dev20250822031319.dist-info/METADATA,sha256=PCvLHS6_1ZGwgU1Gn5v-k9J6mIBDGzy7uHphfk9nO5o,24504
101
+ liger_kernel_nightly-0.6.2.dev20250822031319.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
102
+ liger_kernel_nightly-0.6.2.dev20250822031319.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
103
+ liger_kernel_nightly-0.6.2.dev20250822031319.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
104
+ liger_kernel_nightly-0.6.2.dev20250822031319.dist-info/RECORD,,