liger-kernel 0.6.2__py3-none-any.whl → 0.6.3__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
Files changed (46)
  1. liger_kernel/chunked_loss/fused_linear_ppo.py +4 -0
  2. liger_kernel/chunked_loss/grpo_loss.py +38 -4
  3. liger_kernel/chunked_loss/jsd_loss.py +5 -2
  4. liger_kernel/ops/cross_entropy.py +59 -53
  5. liger_kernel/ops/fused_linear_cross_entropy.py +68 -10
  6. liger_kernel/ops/layer_norm.py +4 -6
  7. liger_kernel/ops/poly_norm.py +386 -0
  8. liger_kernel/transformers/__init__.py +17 -0
  9. liger_kernel/transformers/functional.py +7 -0
  10. liger_kernel/transformers/fused_linear_cross_entropy.py +5 -1
  11. liger_kernel/transformers/model/falcon_h1.py +108 -0
  12. liger_kernel/transformers/model/gemma.py +2 -1
  13. liger_kernel/transformers/model/gemma2.py +8 -2
  14. liger_kernel/transformers/model/gemma3.py +27 -2
  15. liger_kernel/transformers/model/glm4.py +2 -1
  16. liger_kernel/transformers/model/glm4v.py +3 -2
  17. liger_kernel/transformers/model/glm4v_moe.py +153 -0
  18. liger_kernel/transformers/model/internvl.py +150 -0
  19. liger_kernel/transformers/model/llama.py +2 -1
  20. liger_kernel/transformers/model/llama4.py +2 -1
  21. liger_kernel/transformers/model/llava.py +6 -2
  22. liger_kernel/transformers/model/loss_utils.py +1 -0
  23. liger_kernel/transformers/model/mistral.py +2 -1
  24. liger_kernel/transformers/model/mixtral.py +8 -2
  25. liger_kernel/transformers/model/mllama.py +2 -1
  26. liger_kernel/transformers/model/olmo2.py +2 -1
  27. liger_kernel/transformers/model/paligemma.py +19 -0
  28. liger_kernel/transformers/model/phi3.py +2 -1
  29. liger_kernel/transformers/model/qwen2.py +2 -1
  30. liger_kernel/transformers/model/qwen2_5_vl.py +7 -2
  31. liger_kernel/transformers/model/qwen2_vl.py +7 -2
  32. liger_kernel/transformers/model/qwen3.py +2 -1
  33. liger_kernel/transformers/model/qwen3_moe.py +8 -2
  34. liger_kernel/transformers/model/qwen3_next.py +134 -0
  35. liger_kernel/transformers/model/smollm3.py +2 -1
  36. liger_kernel/transformers/model/smolvlm.py +158 -0
  37. liger_kernel/transformers/monkey_patch.py +452 -3
  38. liger_kernel/transformers/multi_token_attention.py +1 -1
  39. liger_kernel/transformers/poly_norm.py +42 -0
  40. liger_kernel/transformers/rms_norm.py +7 -0
  41. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/METADATA +13 -10
  42. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/RECORD +46 -39
  43. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/WHEEL +0 -0
  44. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/licenses/LICENSE +0 -0
  45. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/licenses/NOTICE +0 -0
  46. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/fused_linear_ppo.py

@@ -34,6 +34,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
  beta=0.04,
  loss_type="bnpo",
  max_completion_length=None,
+ importance_sampling_level="token",
  temperature=1.0,
  compiled=True,
  use_ref_model=False,
@@ -92,6 +93,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
  beta=beta,
  loss_type=loss_type,
  max_completion_length=max_completion_length,
+ importance_sampling_level=importance_sampling_level,
  temperature=temperature,
  use_ref_model=use_ref_model,
  ppo_loss_fn=cls.ppo_loss_fn,
@@ -261,6 +263,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
  beta=0.04,
  loss_type="bnpo",
  max_completion_length=None,
+ importance_sampling_level="token",
  temperature=1.0,
  use_ref_model=False,
  ppo_loss_fn=None,
@@ -292,6 +295,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
  beta=beta,
  loss_type=loss_type,
  max_completion_length=max_completion_length,
+ importance_sampling_level=importance_sampling_level,
  )

  return chunk_loss, chunk_metrics
liger_kernel/chunked_loss/grpo_loss.py

@@ -31,6 +31,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
  beta=0.04,
  loss_type="bnpo", # ["grpo", "bnpo", "dr_grpo"]
  max_completion_length=None, # Required for dr_grpo
+ importance_sampling_level="token", # ["token", "sequence"] - new parameter for GSPO
  **kwargs,
  ):
  """GRPO Loss Function matching GRPOTrainer implementation."""
@@ -50,7 +51,22 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):

  # Compute policy gradient loss with importance sampling ratio
  old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
- coef_1 = torch.exp(per_token_logps - old_per_token_logps)
+ log_ratio = per_token_logps - old_per_token_logps
+
+ if importance_sampling_level == "token":
+ log_importance_weights = log_ratio
+ elif importance_sampling_level == "sequence":
+ log_importance_weights = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
+ log_importance_weights = log_importance_weights.unsqueeze(-1)
+ else:
+ raise ValueError(
+ f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
+ "and 'sequence'."
+ )
+
+ # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
+ # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
+ coef_1 = torch.exp(log_importance_weights)
  coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
  per_token_loss1 = coef_1 * advantages.unsqueeze(1)
  per_token_loss2 = coef_2 * advantages.unsqueeze(1)
@@ -85,9 +101,19 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
  metrics = []
  if beta != 0.0:
  metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
- is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
- (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
- )
+
+ # Adjust clipping metric calculation based on importance sampling level
+ if importance_sampling_level == "token":
+ is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
+ (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
+ )
+ else: # sequence level
+ # For sequence level, coef_1 is shape (B, 1), advantages is shape (B,)
+ is_clipped = ((coef_1.squeeze(-1) < 1 - epsilon_low) & (advantages < 0)) | (
+ (coef_1.squeeze(-1) > 1 + epsilon_high) & (advantages > 0)
+ )
+ is_clipped = is_clipped.unsqueeze(1).expand_as(attention_mask)
+
  metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
  return loss, metrics

@@ -111,6 +137,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
  epsilon_high=0.2,
  loss_type="bnpo",
  max_completion_length=None,
+ importance_sampling_level="token",
  temperature=1.0,
  compiled=True,
  use_ref_model=True,
@@ -132,6 +159,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
  beta (float): Weight for the KL penalty
  loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
  max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+ importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
  temperature (float): Temperature for the logits
  compiled (bool): Whether to use torch compile
  use_ref_model (bool): Whether to use a reference model
@@ -162,6 +190,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
  compiled=compiled,
  use_ref_model=use_ref_model,
  chunk_size=chunk_size,
+ importance_sampling_level=importance_sampling_level,
  )

  @staticmethod
@@ -187,6 +216,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
  None, # grad_epsilon_high
  None, # grad_loss_type (string, not differentiable)
  None, # grad_max_completion_length (int, not differentiable)
+ None, # grad_importance_sampling_level (string, not differentiable)
  None, # grad_temperature
  None, # grad_compiled
  None, # grad_use_ref_model
@@ -207,6 +237,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
  epsilon_high: float = 0.2,
  loss_type: str = "bnpo",
  max_completion_length: Optional[int] = None,
+ importance_sampling_level: str = "token",
  temperature: float = 1.0,
  ):
  """
@@ -219,6 +250,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
  epsilon_high (float): Upper bound for the importance sampling ratio.
  loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
  max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+ importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
  temperature (float): Temperature for the logits.
  """
  super().__init__()
@@ -230,6 +262,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
  self.epsilon_high = epsilon_high
  self.loss_type = loss_type
  self.max_completion_length = max_completion_length
+ self.importance_sampling_level = importance_sampling_level
  self.temperature = temperature

  def forward(
@@ -263,6 +296,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
  self.epsilon_high,
  self.loss_type,
  self.max_completion_length,
+ self.importance_sampling_level,
  self.temperature,
  self.compiled,
  self.use_ref_model,
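The new importance_sampling_level option switches between the standard GRPO per-token ratio and a GSPO-style sequence-level ratio. As a point of reference, the branch added above reduces to the following standalone sketch; the helper name and signature are illustrative and not part of the package API:

import torch

def compute_log_importance_weights(per_token_logps, old_per_token_logps, attention_mask,
                                    importance_sampling_level="token"):
    # "token": keep the per-token log-ratio, shape (B, T), as in standard GRPO.
    # "sequence": average the masked log-ratio over each completion, shape (B, 1), GSPO-style.
    log_ratio = per_token_logps - old_per_token_logps
    if importance_sampling_level == "token":
        return log_ratio
    if importance_sampling_level == "sequence":
        seq_mean = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
        return seq_mean.unsqueeze(-1)
    raise ValueError(f"Unknown importance sampling level: {importance_sampling_level}")

The clipping metric follows the same split: at sequence level the (B, 1) coefficient is compared against the per-sequence advantages and then broadcast back over the token dimension before masking.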
liger_kernel/chunked_loss/jsd_loss.py

@@ -1,3 +1,5 @@
+ import math
+
  import torch
  import torch.nn.functional as F

@@ -25,8 +27,9 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
  jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
  else:
  # Compute probabilities (only required for mean calculation)
- mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
- log_mean_probs = mean_probs.log()
+ log_mean_probs = torch.logsumexp(
+ torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
+ )

  student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
  teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
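The replaced lines compute the log of the mixture distribution directly in log space, which avoids materializing the probabilities and underflowing when both log-probabilities are very negative. A minimal sketch of the identity being used, assuming 0 < beta < 1 (the helper name is illustrative):

import math
import torch

def stable_log_mixture(student_log_probs, teacher_log_probs, beta):
    # log((1 - beta) * exp(s) + beta * exp(t)) without exponentiating first:
    # fold the mixture weights into log space and reduce the stacked pair with logsumexp.
    return torch.logsumexp(
        torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0),
        dim=0,
    )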
liger_kernel/ops/cross_entropy.py

@@ -45,6 +45,7 @@ def liger_cross_entropy_kernel(
  BLOCK_SIZE: tl.constexpr,
  HAS_WEIGHT: tl.constexpr,
  HAS_SOFTCAPPING: tl.constexpr,
+ HAS_GRADIENTS: tl.constexpr,
  ):
  """
  This kernel computes both cross entropy loss and the gradient of the input.
@@ -72,6 +73,7 @@ def liger_cross_entropy_kernel(
  BLOCK_SIZE (int): The block size for Triton operations.
  HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
  HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
+ HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
  """

  # https://github.com/triton-lang/triton/issues/1058
@@ -155,58 +157,58 @@ def liger_cross_entropy_kernel(
  # For 'sum' reduction, no normalization is applied:
  # dx_y = softmax(x_y) - 1
  # dx_i = softmax(x_i), for i ≠ y
-
- for i in range(0, n_cols, BLOCK_SIZE):
- X_offsets = i + tl.arange(0, BLOCK_SIZE)
- X_block = tl.load(
- X_ptr + X_offsets,
- mask=X_offsets < n_cols,
- other=float("-inf"),
- # Ensure float32 precision for softmax calculation
- ).cast(tl.float32)
- if HAS_SOFTCAPPING:
- intermediate = tanh(X_block / softcap)
- X_block = softcap * intermediate
-
- if not HAS_WEIGHT:
- # softmax(x_i)
- X_block = tl.exp(X_block - m) / d
- # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
- X_block += 2 * lse_square_scale * lse * X_block
- # smoothing term
- X_block += -eps
- # special handle dx_y
- X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
- # reduction scale
- if reduction == "mean":
- X_block = X_block / n_non_ignore
- else:
- weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
- softmax_X = tl.exp(X_block - m) / d
- # derivative of original_loss
- dloss_ori = (1 - label_smoothing) * softmax_X
- # specially handle dx_y
- dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
- dloss_ori = dloss_ori * weight_y
- # derivative of smooth_loss
- dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
- # derivative of z-loss
- dz_loss = 2 * lse_square_scale * lse * softmax_X
- # reduction scale
- if reduction == "mean":
- dloss_ori = dloss_ori / sum_non_ignore_weight
- dloss_smooth = dloss_smooth / sum_non_ignore_weight
- # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
- dz_loss = dz_loss / n_non_ignore
- # derivative of total_loss
- X_block = dloss_ori + dloss_smooth + dz_loss
-
- # chain rule softcapping
- # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
- if HAS_SOFTCAPPING:
- X_block = X_block * (1 - intermediate * intermediate)
-
- tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
+ if HAS_GRADIENTS:
+ for i in range(0, n_cols, BLOCK_SIZE):
+ X_offsets = i + tl.arange(0, BLOCK_SIZE)
+ X_block = tl.load(
+ X_ptr + X_offsets,
+ mask=X_offsets < n_cols,
+ other=float("-inf"),
+ # Ensure float32 precision for softmax calculation
+ ).cast(tl.float32)
+ if HAS_SOFTCAPPING:
+ intermediate = tanh(X_block / softcap)
+ X_block = softcap * intermediate
+
+ if not HAS_WEIGHT:
+ # softmax(x_i)
+ X_block = tl.exp(X_block - m) / d
+ # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+ X_block += 2 * lse_square_scale * lse * X_block
+ # smoothing term
+ X_block += -eps
+ # special handle dx_y
+ X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+ # reduction scale
+ if reduction == "mean":
+ X_block = X_block / n_non_ignore
+ else:
+ weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+ softmax_X = tl.exp(X_block - m) / d
+ # derivative of original_loss
+ dloss_ori = (1 - label_smoothing) * softmax_X
+ # specially handle dx_y
+ dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+ dloss_ori = dloss_ori * weight_y
+ # derivative of smooth_loss
+ dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+ # derivative of z-loss
+ dz_loss = 2 * lse_square_scale * lse * softmax_X
+ # reduction scale
+ if reduction == "mean":
+ dloss_ori = dloss_ori / sum_non_ignore_weight
+ dloss_smooth = dloss_smooth / sum_non_ignore_weight
+ # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+ dz_loss = dz_loss / n_non_ignore
+ # derivative of total_loss
+ X_block = dloss_ori + dloss_smooth + dz_loss
+
+ # chain rule softcapping
+ # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+ if HAS_SOFTCAPPING:
+ X_block = X_block * (1 - intermediate * intermediate)
+
+ tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

  # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
  # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
@@ -332,6 +334,7 @@ def cross_entropy_forward(
  BLOCK_SIZE=BLOCK_SIZE,
  HAS_WEIGHT=True if weight is not None else False,
  HAS_SOFTCAPPING=True if softcap is not None else False,
+ HAS_GRADIENTS=_input.requires_grad,
  # TODO: 32 seems to give the best performance
  # Performance is quite sensitive to num_warps
  num_warps=32 if not is_hip() else 16,
@@ -411,6 +414,8 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  Returns:
  tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None.
  """
+ input_requires_grad = _input.requires_grad
+
  loss, z_loss, _input = cross_entropy_forward(
  _input,
  target,
@@ -425,7 +430,8 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  # TODO: investigation
  # If we don't detach the _input tensor, the memory will double
  # Not sure why but seems that there will be a time both grad and value exist but in different location
- ctx.save_for_backward(_input.detach())
+ if input_requires_grad:
+ ctx.save_for_backward(_input.detach())
  ctx.return_z_loss = return_z_loss

  return loss, z_loss
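The HAS_GRADIENTS flag lets the kernel skip the in-place gradient computation when the logits do not require gradients (for example during evaluation), and save_for_backward is skipped in the same case so no extra copy of the logits is kept. A minimal sketch of that gating pattern in plain PyTorch, not the actual Triton kernel:

import torch
import torch.nn.functional as F

class LossOnlyCrossEntropy(torch.autograd.Function):
    # Illustrative only: do gradient-related work in forward only when the input
    # actually requires gradients, mirroring HAS_GRADIENTS=_input.requires_grad above.
    @staticmethod
    def forward(ctx, logits, target):
        needs_grad = logits.requires_grad
        loss = F.cross_entropy(logits, target)
        if needs_grad:
            ctx.save_for_backward(logits, target)  # only pay the memory cost when training
        ctx.needs_grad = needs_grad
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        if not ctx.needs_grad:
            return None, None
        logits, target = ctx.saved_tensors
        grad = torch.softmax(logits, dim=-1)
        grad[torch.arange(logits.shape[0], device=logits.device), target] -= 1.0
        return grad_output * grad / logits.shape[0], None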
liger_kernel/ops/fused_linear_cross_entropy.py

@@ -26,10 +26,13 @@ def fused_linear_cross_entropy_forward(
  softcap=None,
  return_z_loss=False,
  accum_dtype=None,
+ use_token_scaling=False,
  ):
  assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
  device = _input.device

+ input_requires_grad = _input.requires_grad
+
  # inputs have shape: BT x H
  # materialized activations will have shape: BT x V
  # the increase in memory = BT x V
@@ -48,12 +51,13 @@ def fused_linear_cross_entropy_forward(
  grad_input = torch.zeros_like(_input, device=device)

  # we use fp32 for loss and gradients accumulator
- if accum_dtype is None:
- grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
- grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
- else:
- grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
- grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
+ if input_requires_grad:
+ if accum_dtype is None:
+ grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
+ grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
+ else:
+ grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
+ grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None

  loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
  z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
@@ -89,6 +93,36 @@ def fused_linear_cross_entropy_forward(

  n_rows = logits_chunk.shape[0]

+ # Compute predicted probabilities for token scaling if needed
+ if use_token_scaling:
+ # Compute softmax probabilities for scaling
+ # We need to compute this before the cross entropy kernel modifies logits_chunk
+ logits_for_softmax = logits_chunk.detach().clone() # Detach to avoid gradient flow
+ if softcap is not None:
+ logits_for_softmax = softcap * torch.tanh(logits_for_softmax / softcap)
+
+ # Compute softmax to get predicted probabilities
+ probs = torch.softmax(logits_for_softmax, dim=-1)
+
+ # Get predicted probabilities for token scaling, handling ignored targets
+ valid_target_mask = target_chunk != ignore_index
+ valid_targets = target_chunk[valid_target_mask]
+
+ if len(valid_targets) > 0:
+ # Gather probabilities only for valid targets
+ valid_probs = probs[valid_target_mask]
+ pred_probs_valid = torch.gather(valid_probs, -1, valid_targets.unsqueeze(-1)).squeeze(-1)
+
+ # Create full tensor with zeros for ignored targets
+ pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
+ pred_probs[valid_target_mask] = pred_probs_valid
+ else:
+ # All targets are ignored
+ pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
+
+ # Store the scaling factors
+ scaling_factors = pred_probs.detach() # Detach to ensure no gradient flow
+
  # unreduced loss
  loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size,
  z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
@@ -119,24 +153,38 @@ def fused_linear_cross_entropy_forward(
  RETURN_Z_LOSS=return_z_loss,
  HAS_WEIGHT=True if ce_weight is not None else False,
  HAS_SOFTCAPPING=True if softcap is not None else False,
+ HAS_GRADIENTS=input_requires_grad,
  BLOCK_SIZE=BLOCK_SIZE,
  num_warps=32 if not is_hip() else 16,
  )

+ # Apply token scaling if requested
+ if use_token_scaling:
+ loss_1d_slice = loss_1d_slice * scaling_factors
+ if return_z_loss:
+ z_loss_1d_slice = z_loss_1d_slice * scaling_factors
+
  loss_1d[start_idx:end_idx] = loss_1d_slice
  if return_z_loss:
  z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
  grad_logits_chunk = logits_chunk # chunk_size x V

- grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
+ # Apply token scaling to gradients if requested
+ if use_token_scaling:
+ # Expand scaling factors to match gradient dimensions
+ scaling_factors_expanded = scaling_factors.unsqueeze(-1) # chunk_size x 1
+ grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded

- if grad_weight is not None:
+ if input_requires_grad:
+ grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
+
+ if grad_weight is not None and input_requires_grad:
  grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()

- if bias is not None:
+ if bias is not None and input_requires_grad:
  torch.add(
  input=grad_bias,
- other=logits_chunk.sum(dim=0),
+ other=grad_logits_chunk.sum(dim=0),
  out=grad_bias,
  alpha=1.0,
  )
@@ -146,6 +194,10 @@ def fused_linear_cross_entropy_forward(
  # loss = loss_1d
  # z_loss = z_loss_1d if return_z_loss else None

+ if reduction == "none":
+ # Return per-token losses
+ loss = loss_1d
+ z_loss = z_loss_1d if return_z_loss else None
  else:
  loss = torch.sum(loss_1d)
  z_loss = torch.sum(z_loss_1d) if return_z_loss else None
@@ -221,6 +273,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  softcap=None,
  return_z_loss: bool = False,
  accum_dtype=None,
+ use_token_scaling: bool = False,
  ):
  """
  Fusing the last linear layer with cross-entropy loss
@@ -241,6 +294,9 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  reduction: reduction to apply
  accum_dtype (torch.dtype): the dtype of intermediate result buffers for weight and bias gradient accumulations.
  Recommended to set `accum_dtype` to higher precision, e.g. `torch.float32`, if the training is unstable with original dtype. Default: `None`, performing accumulations in original dtype
+ use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
+ When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
+ Default: False.
  """

  loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
@@ -256,6 +312,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  softcap=softcap,
  return_z_loss=return_z_loss,
  accum_dtype=accum_dtype,
+ use_token_scaling=use_token_scaling,
  )
  # downcast to dtype and store for backward
  ctx.save_for_backward(
@@ -288,4 +345,5 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  None,
  None,
  None,
+ None, # use_token_scaling
  )
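The use_token_scaling flag multiplies each token's loss, and its gradient, by the model's detached predicted probability of that token's true class; ignored targets contribute zero. An unfused plain-PyTorch reference for those semantics (names are illustrative, sum reduction assumed):

import torch
import torch.nn.functional as F

def token_scaled_ce_reference(hidden, weight, target, ignore_index=-100):
    # hidden: (BT, H), weight: (V, H), target: (BT,)
    logits = hidden @ weight.t()
    per_token = F.cross_entropy(logits, target, ignore_index=ignore_index, reduction="none")
    probs = torch.softmax(logits.detach(), dim=-1)  # detached, so the scale carries no gradient
    scale = torch.zeros_like(per_token)
    valid = target != ignore_index
    scale[valid] = probs[valid].gather(-1, target[valid].unsqueeze(-1)).squeeze(-1)
    return (per_token * scale).sum()

Because the scale is detached, each token's gradient is just the usual cross-entropy gradient multiplied by the same factor, which is what the chunked loop above applies to grad_logits_chunk.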
liger_kernel/ops/layer_norm.py

@@ -63,12 +63,11 @@ def _layer_norm_forward_kernel(
  X_f32 = X_row.to(tl.float32)

  # Compute statistics in fp32 for numerical stability
- n_cols_f32 = n_cols.to(tl.float32)
- mean = tl.sum(X_f32, axis=0) / n_cols_f32
+ mean = tl.sum(X_f32, axis=0) / n_cols
  X_centered = X_f32 - mean
  # Apply mask to variance calculation to exclude contributions from masked elements
  X_centered_masked = tl.where(mask, X_centered, 0.0)
- var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols_f32
+ var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols
  rstd = rsqrt(var + eps)

  # Store statistics (convert back to original dtype only once)
@@ -113,7 +112,6 @@ def _layer_norm_backward_kernel(
  # Pre-load weights once (same optimization as forward pass)
  w = tl.load(W_ptr + cols, mask=mask, other=0.0)
  w_f32 = w.to(tl.float32)
- n_cols_f32 = n_cols.to(tl.float32)

  # Calculate pointers for this specific row
  row_X_ptr = X_ptr + row_idx * stride_x
@@ -137,8 +135,8 @@ def _layer_norm_backward_kernel(
  # Compute backward pass for this row
  x_hat = (x_f32 - mean_f32) * rstd_f32
  wdy = w_f32 * dy_f32
- c1 = tl.sum(x_hat * wdy, axis=0) / n_cols_f32
- c2 = tl.sum(wdy, axis=0) / n_cols_f32
+ c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
+ c2 = tl.sum(wdy, axis=0) / n_cols
  dx = (wdy - (x_hat * c1 + c2)) * rstd_f32

  # Store input gradient
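For reference, the c1/c2 terms in the last hunk implement the standard LayerNorm input-gradient formula; dropping the explicit n_cols_f32 cast only removes a redundant conversion:

\hat{x} = (x - \mu)\,\mathrm{rstd}, \qquad
c_1 = \tfrac{1}{n}\textstyle\sum_i \hat{x}_i\,(w \odot dy)_i, \qquad
c_2 = \tfrac{1}{n}\textstyle\sum_i (w \odot dy)_i, \qquad
dx = \big(w \odot dy - (\hat{x}\,c_1 + c_2)\big)\,\mathrm{rstd}.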