liger-kernel 0.6.1__py3-none-any.whl → 0.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. liger_kernel/chunked_loss/dpo_loss.py +54 -3
  2. liger_kernel/chunked_loss/fused_linear_ppo.py +4 -0
  3. liger_kernel/chunked_loss/grpo_loss.py +38 -4
  4. liger_kernel/chunked_loss/jsd_loss.py +5 -2
  5. liger_kernel/ops/cross_entropy.py +59 -53
  6. liger_kernel/ops/fused_linear_cross_entropy.py +83 -17
  7. liger_kernel/ops/layer_norm.py +4 -6
  8. liger_kernel/ops/llama4_rope.py +225 -0
  9. liger_kernel/ops/poly_norm.py +386 -0
  10. liger_kernel/transformers/__init__.py +32 -0
  11. liger_kernel/transformers/experimental/__init__.py +5 -0
  12. liger_kernel/transformers/functional.py +9 -0
  13. liger_kernel/transformers/fused_linear_cross_entropy.py +8 -1
  14. liger_kernel/transformers/llama4_rope.py +93 -0
  15. liger_kernel/transformers/model/falcon_h1.py +108 -0
  16. liger_kernel/transformers/model/gemma.py +2 -1
  17. liger_kernel/transformers/model/gemma2.py +8 -2
  18. liger_kernel/transformers/model/gemma3.py +27 -2
  19. liger_kernel/transformers/model/glm4.py +2 -1
  20. liger_kernel/transformers/model/glm4v.py +151 -0
  21. liger_kernel/transformers/model/glm4v_moe.py +153 -0
  22. liger_kernel/transformers/model/internvl.py +150 -0
  23. liger_kernel/transformers/model/llama.py +2 -1
  24. liger_kernel/transformers/model/llama4.py +2 -1
  25. liger_kernel/transformers/model/llava.py +6 -2
  26. liger_kernel/transformers/model/loss_utils.py +3 -0
  27. liger_kernel/transformers/model/mistral.py +2 -1
  28. liger_kernel/transformers/model/mixtral.py +8 -2
  29. liger_kernel/transformers/model/mllama.py +6 -3
  30. liger_kernel/transformers/model/olmo2.py +2 -1
  31. liger_kernel/transformers/model/paligemma.py +19 -0
  32. liger_kernel/transformers/model/phi3.py +10 -160
  33. liger_kernel/transformers/model/qwen2.py +2 -1
  34. liger_kernel/transformers/model/qwen2_5_vl.py +7 -2
  35. liger_kernel/transformers/model/qwen2_vl.py +7 -2
  36. liger_kernel/transformers/model/qwen3.py +2 -1
  37. liger_kernel/transformers/model/qwen3_moe.py +8 -2
  38. liger_kernel/transformers/model/qwen3_next.py +134 -0
  39. liger_kernel/transformers/model/smollm3.py +2 -1
  40. liger_kernel/transformers/model/smolvlm.py +158 -0
  41. liger_kernel/transformers/monkey_patch.py +552 -23
  42. liger_kernel/transformers/multi_token_attention.py +1 -1
  43. liger_kernel/transformers/poly_norm.py +42 -0
  44. liger_kernel/transformers/rms_norm.py +7 -0
  45. {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/METADATA +14 -11
  46. {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/RECORD +50 -39
  47. {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/WHEEL +0 -0
  48. {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/licenses/LICENSE +0 -0
  49. {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/licenses/NOTICE +0 -0
  50. {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/top_level.txt +0 -0

liger_kernel/chunked_loss/dpo_loss.py

@@ -13,6 +13,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
  ref_chosen_logps=None,
  ref_rejected_logps=None,
  beta=0.1,
+ loss_type="sigmoid",
  ):
  """
  Paper: https://arxiv.org/pdf/2305.18290
@@ -48,8 +49,50 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
  chosen_rewards = beta * chosen_logratios
  rejected_rewards = beta * rejected_logratios

- logits_diff = beta * (chosen_logratios - rejected_logratios)
- loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+ if loss_type == "sigmoid":
+ logits_diff = beta * (chosen_logratios - rejected_logratios)
+ loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+
+ elif loss_type == "apo_zero":
+ # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
+ # Use this loss when you believe the chosen outputs are better than your model's default output
+ losses_chosen = 1 - F.sigmoid(beta * chosen_logratios) # Increase chosen likelihood
+ losses_rejected = F.sigmoid(beta * rejected_logratios)
+ losses = losses_chosen + losses_rejected
+ loss = losses.sum() / (full_target.shape[0] // 2)
+
+ elif loss_type == "apo_down":
+ # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266)
+ # Use this loss when you believe the chosen outputs are worse than your model's default output.
+ # Decrease chosen likelihood and decrease rejected likelihood more
+ losses_chosen = F.sigmoid(beta * chosen_logratios)
+ losses_rejected = 1 - F.sigmoid(beta * (chosen_logratios - rejected_logratios))
+ losses = losses_chosen + losses_rejected
+ loss = losses.sum() / (full_target.shape[0] // 2)
+
+ elif loss_type == "sppo_hard":
+ # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach,
+ # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class.
+ # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is
+ # set to 1 for the winner and 0 for the loser.
+ a = chosen_logps - ref_chosen_logps
+ b = rejected_logps - ref_rejected_logps
+ losses = (a - 0.5 / beta) ** 2 + (b + 0.5 / beta) ** 2
+ loss = losses.sum() / (full_target.shape[0] // 2)
+
+ elif loss_type == "nca_pair":
+ losses = (
+ -F.logsigmoid(chosen_rewards)
+ - 0.5 * F.logsigmoid(-chosen_rewards)
+ - 0.5 * F.logsigmoid(-rejected_rewards)
+ )
+ loss = losses.sum() / (full_target.shape[0] // 2)
+
+ else:
+ raise ValueError(
+ f"Unsupported loss_type: {loss_type}. Supported types are: sigmoid, apo_zero, apo_down, sppo_hard, nca_pair"
+ )
+
  return loss, chosen_rewards, rejected_rewards

  @classmethod
@@ -70,6 +113,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
  use_ref_model=True,
  average_log_prob=False,
  chunk_size=1,
+ loss_type="sigmoid",
  ):
  """
  Fused linear layer with DPO loss.
@@ -108,12 +152,13 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
  ref_bias=ref_bias,
  average_log_prob=average_log_prob,
  chunk_size=chunk_size,
+ loss_type=loss_type,
  )

  @staticmethod
  def backward(ctx, *grad_output):
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
- return *grads, None, None, None, None, None, None, None, None, None, None
+ return *grads, None, None, None, None, None, None, None, None, None, None, None


  class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -130,6 +175,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
  use_ref_model: bool = True,
  average_log_prob: bool = False,
  chunk_size: int = 1,
+ loss_type: str = "sigmoid",
  ):
  """
  Args:
@@ -149,6 +195,10 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
  self.use_ref_model = use_ref_model
  self.average_log_prob = average_log_prob
  self.chunk_size = chunk_size
+ self.loss_type = loss_type
+ supported_loss_types = {"sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"}
+ if self.loss_type not in supported_loss_types:
+ raise ValueError(f"Unsupported loss_type: {self.loss_type}. Supported types are: {supported_loss_types}")

  def forward(
  self,
@@ -175,4 +225,5 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
  self.use_ref_model,
  self.average_log_prob,
  self.chunk_size,
+ self.loss_type,
  )
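
For reference, a minimal usage sketch of the new loss_type switch. Only keyword arguments visible in this diff are used; the import path follows the package's chunked_loss layout and is an assumption here, not something this diff shows.

from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss  # import path assumed

# Pick one of the newly supported variants; "sigmoid" remains the default.
dpo_loss = LigerFusedLinearDPOLoss(
    use_ref_model=True,
    average_log_prob=False,
    chunk_size=1,
    loss_type="apo_zero",  # also: "sigmoid", "apo_down", "sppo_hard", "nca_pair"
)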

liger_kernel/chunked_loss/fused_linear_ppo.py

@@ -34,6 +34,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
  beta=0.04,
  loss_type="bnpo",
  max_completion_length=None,
+ importance_sampling_level="token",
  temperature=1.0,
  compiled=True,
  use_ref_model=False,
@@ -92,6 +93,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
  beta=beta,
  loss_type=loss_type,
  max_completion_length=max_completion_length,
+ importance_sampling_level=importance_sampling_level,
  temperature=temperature,
  use_ref_model=use_ref_model,
  ppo_loss_fn=cls.ppo_loss_fn,
@@ -261,6 +263,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
  beta=0.04,
  loss_type="bnpo",
  max_completion_length=None,
+ importance_sampling_level="token",
  temperature=1.0,
  use_ref_model=False,
  ppo_loss_fn=None,
@@ -292,6 +295,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
  beta=beta,
  loss_type=loss_type,
  max_completion_length=max_completion_length,
+ importance_sampling_level=importance_sampling_level,
  )

  return chunk_loss, chunk_metrics

liger_kernel/chunked_loss/grpo_loss.py

@@ -31,6 +31,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
  beta=0.04,
  loss_type="bnpo", # ["grpo", "bnpo", "dr_grpo"]
  max_completion_length=None, # Required for dr_grpo
+ importance_sampling_level="token", # ["token", "sequence"] - new parameter for GSPO
  **kwargs,
  ):
  """GRPO Loss Function matching GRPOTrainer implementation."""
@@ -50,7 +51,22 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):

  # Compute policy gradient loss with importance sampling ratio
  old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
- coef_1 = torch.exp(per_token_logps - old_per_token_logps)
+ log_ratio = per_token_logps - old_per_token_logps
+
+ if importance_sampling_level == "token":
+ log_importance_weights = log_ratio
+ elif importance_sampling_level == "sequence":
+ log_importance_weights = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
+ log_importance_weights = log_importance_weights.unsqueeze(-1)
+ else:
+ raise ValueError(
+ f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
+ "and 'sequence'."
+ )
+
+ # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
+ # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
+ coef_1 = torch.exp(log_importance_weights)
  coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
  per_token_loss1 = coef_1 * advantages.unsqueeze(1)
  per_token_loss2 = coef_2 * advantages.unsqueeze(1)
@@ -85,9 +101,19 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
  metrics = []
  if beta != 0.0:
  metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
- is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
- (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
- )
+
+ # Adjust clipping metric calculation based on importance sampling level
+ if importance_sampling_level == "token":
+ is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
+ (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
+ )
+ else: # sequence level
+ # For sequence level, coef_1 is shape (B, 1), advantages is shape (B,)
+ is_clipped = ((coef_1.squeeze(-1) < 1 - epsilon_low) & (advantages < 0)) | (
+ (coef_1.squeeze(-1) > 1 + epsilon_high) & (advantages > 0)
+ )
+ is_clipped = is_clipped.unsqueeze(1).expand_as(attention_mask)
+
  metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
  return loss, metrics

@@ -111,6 +137,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
  epsilon_high=0.2,
  loss_type="bnpo",
  max_completion_length=None,
+ importance_sampling_level="token",
  temperature=1.0,
  compiled=True,
  use_ref_model=True,
@@ -132,6 +159,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
  beta (float): Weight for the KL penalty
  loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
  max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+ importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
  temperature (float): Temperature for the logits
  compiled (bool): Whether to use torch compile
  use_ref_model (bool): Whether to use a reference model
@@ -162,6 +190,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
  compiled=compiled,
  use_ref_model=use_ref_model,
  chunk_size=chunk_size,
+ importance_sampling_level=importance_sampling_level,
  )

  @staticmethod
@@ -187,6 +216,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
  None, # grad_epsilon_high
  None, # grad_loss_type (string, not differentiable)
  None, # grad_max_completion_length (int, not differentiable)
+ None, # grad_importance_sampling_level (string, not differentiable)
  None, # grad_temperature
  None, # grad_compiled
  None, # grad_use_ref_model
@@ -207,6 +237,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
  epsilon_high: float = 0.2,
  loss_type: str = "bnpo",
  max_completion_length: Optional[int] = None,
+ importance_sampling_level: str = "token",
  temperature: float = 1.0,
  ):
  """
@@ -219,6 +250,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
  epsilon_high (float): Upper bound for the importance sampling ratio.
  loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
  max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+ importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
  temperature (float): Temperature for the logits.
  """
  super().__init__()
@@ -230,6 +262,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
  self.epsilon_high = epsilon_high
  self.loss_type = loss_type
  self.max_completion_length = max_completion_length
+ self.importance_sampling_level = importance_sampling_level
  self.temperature = temperature

  def forward(
@@ -263,6 +296,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
  self.epsilon_high,
  self.loss_type,
  self.max_completion_length,
+ self.importance_sampling_level,
  self.temperature,
  self.compiled,
  self.use_ref_model,
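
A standalone sketch of the sequence-level ("GSPO") importance weights added above, using the same masked-mean formula as the diff; the (B, T) shapes and toy tensors are assumptions for illustration.

import torch

B, T = 2, 5  # toy batch: 2 sequences, 5 completion tokens each
per_token_logps = torch.randn(B, T)
old_per_token_logps = torch.randn(B, T)
attention_mask = torch.ones(B, T)

log_ratio = per_token_logps - old_per_token_logps
# "sequence" level: average the per-token log ratios over valid tokens,
# giving one importance weight per sequence, kept as shape (B, 1)
log_importance_weights = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
coef_1 = torch.exp(log_importance_weights).unsqueeze(-1)  # (B, 1), vs. (B, T) for "token" level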

liger_kernel/chunked_loss/jsd_loss.py

@@ -1,3 +1,5 @@
+ import math
+
  import torch
  import torch.nn.functional as F

@@ -25,8 +27,9 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
  jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
  else:
  # Compute probabilities (only required for mean calculation)
- mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
- log_mean_probs = mean_probs.log()
+ log_mean_probs = torch.logsumexp(
+ torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
+ )

  student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
  teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
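
The rewritten mean computation is the usual log-sum-exp identity, log((1 - beta) * p_s + beta * p_t) = logsumexp([log p_s + log(1 - beta), log p_t + log beta]), which avoids exponentiating small probabilities and then taking the log. A small numerical check of that equivalence (toy shapes assumed):

import math

import torch

beta = 0.5
student_log_probs = torch.log_softmax(torch.randn(4, 10), dim=-1)
teacher_log_probs = torch.log_softmax(torch.randn(4, 10), dim=-1)

# log-sum-exp over the stacked, weight-shifted log-probs (as in the new code)
log_mean_probs = torch.logsumexp(
    torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
)
# naive version from the old code: exponentiate, mix, take the log
naive = ((1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()).log()
assert torch.allclose(log_mean_probs, naive, atol=1e-5)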

liger_kernel/ops/cross_entropy.py

@@ -45,6 +45,7 @@ def liger_cross_entropy_kernel(
  BLOCK_SIZE: tl.constexpr,
  HAS_WEIGHT: tl.constexpr,
  HAS_SOFTCAPPING: tl.constexpr,
+ HAS_GRADIENTS: tl.constexpr,
  ):
  """
  This kernel computes both cross entropy loss and the gradient of the input.
@@ -72,6 +73,7 @@ def liger_cross_entropy_kernel(
  BLOCK_SIZE (int): The block size for Triton operations.
  HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
  HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
+ HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
  """

  # https://github.com/triton-lang/triton/issues/1058
@@ -155,58 +157,58 @@ def liger_cross_entropy_kernel(
  # For 'sum' reduction, no normalization is applied:
  # dx_y = softmax(x_y) - 1
  # dx_i = softmax(x_i), for i ≠ y
-
- for i in range(0, n_cols, BLOCK_SIZE):
- X_offsets = i + tl.arange(0, BLOCK_SIZE)
- X_block = tl.load(
- X_ptr + X_offsets,
- mask=X_offsets < n_cols,
- other=float("-inf"),
- # Ensure float32 precision for softmax calculation
- ).cast(tl.float32)
- if HAS_SOFTCAPPING:
- intermediate = tanh(X_block / softcap)
- X_block = softcap * intermediate
-
- if not HAS_WEIGHT:
- # softmax(x_i)
- X_block = tl.exp(X_block - m) / d
- # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
- X_block += 2 * lse_square_scale * lse * X_block
- # smoothing term
- X_block += -eps
- # special handle dx_y
- X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
- # reduction scale
- if reduction == "mean":
- X_block = X_block / n_non_ignore
- else:
- weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
- softmax_X = tl.exp(X_block - m) / d
- # derivative of original_loss
- dloss_ori = (1 - label_smoothing) * softmax_X
- # specially handle dx_y
- dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
- dloss_ori = dloss_ori * weight_y
- # derivative of smooth_loss
- dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
- # derivative of z-loss
- dz_loss = 2 * lse_square_scale * lse * softmax_X
- # reduction scale
- if reduction == "mean":
- dloss_ori = dloss_ori / sum_non_ignore_weight
- dloss_smooth = dloss_smooth / sum_non_ignore_weight
- # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
- dz_loss = dz_loss / n_non_ignore
- # derivative of total_loss
- X_block = dloss_ori + dloss_smooth + dz_loss
-
- # chain rule softcapping
- # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
- if HAS_SOFTCAPPING:
- X_block = X_block * (1 - intermediate * intermediate)
-
- tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
+ if HAS_GRADIENTS:
+ for i in range(0, n_cols, BLOCK_SIZE):
+ X_offsets = i + tl.arange(0, BLOCK_SIZE)
+ X_block = tl.load(
+ X_ptr + X_offsets,
+ mask=X_offsets < n_cols,
+ other=float("-inf"),
+ # Ensure float32 precision for softmax calculation
+ ).cast(tl.float32)
+ if HAS_SOFTCAPPING:
+ intermediate = tanh(X_block / softcap)
+ X_block = softcap * intermediate
+
+ if not HAS_WEIGHT:
+ # softmax(x_i)
+ X_block = tl.exp(X_block - m) / d
+ # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+ X_block += 2 * lse_square_scale * lse * X_block
+ # smoothing term
+ X_block += -eps
+ # special handle dx_y
+ X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+ # reduction scale
+ if reduction == "mean":
+ X_block = X_block / n_non_ignore
+ else:
+ weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+ softmax_X = tl.exp(X_block - m) / d
+ # derivative of original_loss
+ dloss_ori = (1 - label_smoothing) * softmax_X
+ # specially handle dx_y
+ dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+ dloss_ori = dloss_ori * weight_y
+ # derivative of smooth_loss
+ dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+ # derivative of z-loss
+ dz_loss = 2 * lse_square_scale * lse * softmax_X
+ # reduction scale
+ if reduction == "mean":
+ dloss_ori = dloss_ori / sum_non_ignore_weight
+ dloss_smooth = dloss_smooth / sum_non_ignore_weight
+ # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+ dz_loss = dz_loss / n_non_ignore
+ # derivative of total_loss
+ X_block = dloss_ori + dloss_smooth + dz_loss
+
+ # chain rule softcapping
+ # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+ if HAS_SOFTCAPPING:
+ X_block = X_block * (1 - intermediate * intermediate)
+
+ tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

  # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
  # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
@@ -332,6 +334,7 @@ def cross_entropy_forward(
  BLOCK_SIZE=BLOCK_SIZE,
  HAS_WEIGHT=True if weight is not None else False,
  HAS_SOFTCAPPING=True if softcap is not None else False,
+ HAS_GRADIENTS=_input.requires_grad,
  # TODO: 32 seems to give the best performance
  # Performance is quite sensitive to num_warps
  num_warps=32 if not is_hip() else 16,
@@ -411,6 +414,8 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  Returns:
  tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None.
  """
+ input_requires_grad = _input.requires_grad
+
  loss, z_loss, _input = cross_entropy_forward(
  _input,
  target,
@@ -425,7 +430,8 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  # TODO: investigation
  # If we don't detach the _input tensor, the memory will double
  # Not sure why but seems that there will be a time both grad and value exist but in different location
- ctx.save_for_backward(_input.detach())
+ if input_requires_grad:
+ ctx.save_for_backward(_input.detach())
  ctx.return_z_loss = return_z_loss

  return loss, z_loss
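
A minimal sketch (not the library's own class) of the gating pattern introduced here: both the gradient work (HAS_GRADIENTS) and ctx.save_for_backward are skipped when the input does not require grad, so forward-only calls, e.g. evaluation, avoid the extra compute and saved memory. The toy function below only illustrates that pattern with plain PyTorch ops.

import torch
import torch.nn.functional as F

class ToyCrossEntropy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, _input, target):
        input_requires_grad = _input.requires_grad
        loss = F.cross_entropy(_input.float(), target)
        if input_requires_grad:
            # pay the extra memory/compute only when a backward pass can actually run
            ctx.save_for_backward(torch.softmax(_input.detach().float(), dim=-1), target)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        probs, target = ctx.saved_tensors
        grad = probs.clone()
        grad[torch.arange(target.numel()), target] -= 1.0  # softmax - one_hot
        return grad_output * grad / target.numel(), None

logits = torch.randn(8, 16)  # requires_grad is False -> nothing is saved for backward
eval_loss = ToyCrossEntropy.apply(logits, torch.randint(0, 16, (8,)))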

liger_kernel/ops/fused_linear_cross_entropy.py

@@ -25,10 +25,14 @@ def fused_linear_cross_entropy_forward(
  reduction="mean",
  softcap=None,
  return_z_loss=False,
+ accum_dtype=None,
+ use_token_scaling=False,
  ):
  assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
  device = _input.device

+ input_requires_grad = _input.requires_grad
+
  # inputs have shape: BT x H
  # materialized activations will have shape: BT x V
  # the increase in memory = BT x V
@@ -44,10 +48,17 @@ def fused_linear_cross_entropy_forward(
  chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor
  num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size

- grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
  grad_input = torch.zeros_like(_input, device=device)
- grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
- # we use fp32 for loss accumulator
+
+ # we use fp32 for loss and gradients accumulator
+ if input_requires_grad:
+ if accum_dtype is None:
+ grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
+ grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
+ else:
+ grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
+ grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
+
  loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
  z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None

@@ -82,6 +93,36 @@ def fused_linear_cross_entropy_forward(

  n_rows = logits_chunk.shape[0]

+ # Compute predicted probabilities for token scaling if needed
+ if use_token_scaling:
+ # Compute softmax probabilities for scaling
+ # We need to compute this before the cross entropy kernel modifies logits_chunk
+ logits_for_softmax = logits_chunk.detach().clone() # Detach to avoid gradient flow
+ if softcap is not None:
+ logits_for_softmax = softcap * torch.tanh(logits_for_softmax / softcap)
+
+ # Compute softmax to get predicted probabilities
+ probs = torch.softmax(logits_for_softmax, dim=-1)
+
+ # Get predicted probabilities for token scaling, handling ignored targets
+ valid_target_mask = target_chunk != ignore_index
+ valid_targets = target_chunk[valid_target_mask]
+
+ if len(valid_targets) > 0:
+ # Gather probabilities only for valid targets
+ valid_probs = probs[valid_target_mask]
+ pred_probs_valid = torch.gather(valid_probs, -1, valid_targets.unsqueeze(-1)).squeeze(-1)
+
+ # Create full tensor with zeros for ignored targets
+ pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
+ pred_probs[valid_target_mask] = pred_probs_valid
+ else:
+ # All targets are ignored
+ pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
+
+ # Store the scaling factors
+ scaling_factors = pred_probs.detach() # Detach to ensure no gradient flow
+
  # unreduced loss
  loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size,
  z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
@@ -112,33 +153,38 @@ def fused_linear_cross_entropy_forward(
  RETURN_Z_LOSS=return_z_loss,
  HAS_WEIGHT=True if ce_weight is not None else False,
  HAS_SOFTCAPPING=True if softcap is not None else False,
+ HAS_GRADIENTS=input_requires_grad,
  BLOCK_SIZE=BLOCK_SIZE,
  num_warps=32 if not is_hip() else 16,
  )

+ # Apply token scaling if requested
+ if use_token_scaling:
+ loss_1d_slice = loss_1d_slice * scaling_factors
+ if return_z_loss:
+ z_loss_1d_slice = z_loss_1d_slice * scaling_factors
+
  loss_1d[start_idx:end_idx] = loss_1d_slice
  if return_z_loss:
  z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
  grad_logits_chunk = logits_chunk # chunk_size x V

- grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
+ # Apply token scaling to gradients if requested
+ if use_token_scaling:
+ # Expand scaling factors to match gradient dimensions
+ scaling_factors_expanded = scaling_factors.unsqueeze(-1) # chunk_size x 1
+ grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded

- if grad_weight is not None:
- torch.addmm(
- input=grad_weight,
- mat1=logits_chunk.t().to(
- _input_chunk.dtype
- ), # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error.
- mat2=_input_chunk,
- out=grad_weight,
- alpha=1.0,
- beta=1.0,
- )
+ if input_requires_grad:
+ grad_input[start_idx:end_idx] = grad_logits_chunk @ weight

- if bias is not None:
+ if grad_weight is not None and input_requires_grad:
+ grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()
+
+ if bias is not None and input_requires_grad:
  torch.add(
  input=grad_bias,
- other=logits_chunk.sum(dim=0),
+ other=grad_logits_chunk.sum(dim=0),
  out=grad_bias,
  alpha=1.0,
  )
@@ -148,9 +194,18 @@ def fused_linear_cross_entropy_forward(
  # loss = loss_1d
  # z_loss = z_loss_1d if return_z_loss else None

+ if reduction == "none":
+ # Return per-token losses
+ loss = loss_1d
+ z_loss = z_loss_1d if return_z_loss else None
  else:
  loss = torch.sum(loss_1d)
  z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+
+ # Cast back to original dtype
+ grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None
+ grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None
+
  return loss, z_loss, grad_input, grad_weight, grad_bias


@@ -217,6 +272,8 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  reduction="mean",
  softcap=None,
  return_z_loss: bool = False,
+ accum_dtype=None,
+ use_token_scaling: bool = False,
  ):
  """
  Fusing the last linear layer with cross-entropy loss
@@ -235,6 +292,11 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  ignore_index: the index to ignore in the target
  label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
  reduction: reduction to apply
+ accum_dtype (torch.dtype): the dtype of intermediate result buffers for weight and bias gradient accumulations.
+ Recommended to set `accum_dtype` to higher precision, e.g. `torch.float32`, if the training is unstable with original dtype. Default: `None`, performing accumulations in original dtype
+ use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
+ When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
+ Default: False.
  """

  loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
@@ -249,6 +311,8 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  reduction=reduction,
  softcap=softcap,
  return_z_loss=return_z_loss,
+ accum_dtype=accum_dtype,
+ use_token_scaling=use_token_scaling,
  )
  # downcast to dtype and store for backward
  ctx.save_for_backward(
@@ -280,4 +344,6 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  None,
  None,
  None,
+ None,
+ None, # use_token_scaling
  )
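
A standalone restatement of what use_token_scaling does in the chunk loop above: each token's loss (and its gradient) is scaled by the detached predicted probability of its true class, with ignored targets getting a zero scale. The toy shapes and the plain F.cross_entropy call are assumptions for illustration, not the fused kernel path.

import torch
import torch.nn.functional as F

ignore_index = -100
logits = torch.randn(6, 11)
target = torch.tensor([3, 1, ignore_index, 7, 0, 5])

# detached predicted probability of each token's true class (0 for ignored targets)
probs = torch.softmax(logits.detach(), dim=-1)
valid = target != ignore_index
scaling = torch.zeros(target.shape, dtype=probs.dtype)
scaling[valid] = probs[valid].gather(-1, target[valid].unsqueeze(-1)).squeeze(-1)

per_token_loss = F.cross_entropy(logits, target, reduction="none", ignore_index=ignore_index)
scaled_loss = per_token_loss * scaling  # mirrors loss_1d_slice * scaling_factors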

liger_kernel/ops/layer_norm.py

@@ -63,12 +63,11 @@ def _layer_norm_forward_kernel(
  X_f32 = X_row.to(tl.float32)

  # Compute statistics in fp32 for numerical stability
- n_cols_f32 = n_cols.to(tl.float32)
- mean = tl.sum(X_f32, axis=0) / n_cols_f32
+ mean = tl.sum(X_f32, axis=0) / n_cols
  X_centered = X_f32 - mean
  # Apply mask to variance calculation to exclude contributions from masked elements
  X_centered_masked = tl.where(mask, X_centered, 0.0)
- var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols_f32
+ var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols
  rstd = rsqrt(var + eps)

  # Store statistics (convert back to original dtype only once)
@@ -113,7 +112,6 @@ def _layer_norm_backward_kernel(
  # Pre-load weights once (same optimization as forward pass)
  w = tl.load(W_ptr + cols, mask=mask, other=0.0)
  w_f32 = w.to(tl.float32)
- n_cols_f32 = n_cols.to(tl.float32)

  # Calculate pointers for this specific row
  row_X_ptr = X_ptr + row_idx * stride_x
@@ -137,8 +135,8 @@ def _layer_norm_backward_kernel(
  # Compute backward pass for this row
  x_hat = (x_f32 - mean_f32) * rstd_f32
  wdy = w_f32 * dy_f32
- c1 = tl.sum(x_hat * wdy, axis=0) / n_cols_f32
- c2 = tl.sum(wdy, axis=0) / n_cols_f32
+ c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
+ c2 = tl.sum(wdy, axis=0) / n_cols
  dx = (wdy - (x_hat * c1 + c2)) * rstd_f32

  # Store input gradient
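
For reference, the backward math in this kernel written out in plain PyTorch, per row, with c1 and c2 as the two row means computed above; the toy tensors are assumptions for illustration only.

import torch

x = torch.randn(4, 16)
w = torch.randn(16)
dy = torch.randn(4, 16)
eps = 1e-6

mean = x.mean(dim=-1, keepdim=True)
rstd = torch.rsqrt(x.var(dim=-1, unbiased=False, keepdim=True) + eps)
x_hat = (x - mean) * rstd                        # normalized input
wdy = w * dy                                     # upstream gradient scaled by the affine weight
c1 = (x_hat * wdy).mean(dim=-1, keepdim=True)    # tl.sum(x_hat * wdy) / n_cols
c2 = wdy.mean(dim=-1, keepdim=True)              # tl.sum(wdy) / n_cols
dx = (wdy - (x_hat * c1 + c2)) * rstd            # matches the kernel's dx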