liger-kernel 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +20 -5
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +8 -5
  5. liger_kernel/chunked_loss/jsd_loss.py +39 -11
  6. liger_kernel/ops/__init__.py +141 -0
  7. liger_kernel/ops/backends/README.md +151 -0
  8. liger_kernel/ops/backends/__init__.py +13 -0
  9. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  10. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +492 -0
  11. liger_kernel/ops/backends/_ascend/ops/__init__.py +61 -0
  12. liger_kernel/ops/backends/_ascend/ops/embedding.py +214 -0
  13. liger_kernel/ops/backends/_ascend/ops/geglu.py +191 -0
  14. liger_kernel/ops/backends/_ascend/ops/llama4_rope.py +298 -0
  15. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +275 -0
  16. liger_kernel/ops/backends/_ascend/ops/rope.py +265 -0
  17. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  18. liger_kernel/ops/backends/_ascend/ops/tvd.py +223 -0
  19. liger_kernel/ops/backends/_ascend/ub_manager.py +367 -0
  20. liger_kernel/ops/backends/registry.py +61 -0
  21. liger_kernel/ops/cross_entropy.py +71 -11
  22. liger_kernel/ops/dyt.py +5 -2
  23. liger_kernel/ops/fused_add_rms_norm.py +21 -23
  24. liger_kernel/ops/fused_linear_cross_entropy.py +32 -5
  25. liger_kernel/ops/geglu.py +5 -3
  26. liger_kernel/ops/group_norm.py +12 -8
  27. liger_kernel/ops/grpo_loss.py +3 -1
  28. liger_kernel/ops/kl_div.py +8 -11
  29. liger_kernel/ops/layer_norm.py +89 -69
  30. liger_kernel/ops/poly_norm.py +19 -21
  31. liger_kernel/ops/rms_norm.py +149 -71
  32. liger_kernel/ops/tiled_mlp.py +136 -0
  33. liger_kernel/ops/utils.py +25 -0
  34. liger_kernel/transformers/__init__.py +25 -0
  35. liger_kernel/transformers/auto_model.py +21 -0
  36. liger_kernel/transformers/cross_entropy.py +9 -4
  37. liger_kernel/transformers/dyt.py +1 -1
  38. liger_kernel/transformers/experimental/embedding.py +1 -1
  39. liger_kernel/transformers/functional.py +44 -26
  40. liger_kernel/transformers/fused_add_rms_norm.py +1 -1
  41. liger_kernel/transformers/fused_linear_cross_entropy.py +9 -4
  42. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  43. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  44. liger_kernel/transformers/geglu.py +1 -1
  45. liger_kernel/transformers/group_norm.py +1 -1
  46. liger_kernel/transformers/grpo_loss.py +57 -2
  47. liger_kernel/transformers/jsd.py +1 -1
  48. liger_kernel/transformers/kl_div.py +1 -1
  49. liger_kernel/transformers/layer_norm.py +1 -1
  50. liger_kernel/transformers/llama4_rope.py +1 -1
  51. liger_kernel/transformers/model/exaone4.py +136 -0
  52. liger_kernel/transformers/model/falcon_h1.py +19 -5
  53. liger_kernel/transformers/model/gemma.py +17 -6
  54. liger_kernel/transformers/model/gemma2.py +17 -8
  55. liger_kernel/transformers/model/gemma3.py +35 -16
  56. liger_kernel/transformers/model/glm4.py +16 -4
  57. liger_kernel/transformers/model/glm4v.py +16 -4
  58. liger_kernel/transformers/model/glm4v_moe.py +23 -4
  59. liger_kernel/transformers/model/gpt_oss.py +211 -0
  60. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  61. liger_kernel/transformers/model/internvl.py +12 -5
  62. liger_kernel/transformers/model/llama.py +14 -5
  63. liger_kernel/transformers/model/llama4.py +16 -4
  64. liger_kernel/transformers/model/llava.py +12 -4
  65. liger_kernel/transformers/model/loss_utils.py +37 -3
  66. liger_kernel/transformers/model/mistral.py +15 -6
  67. liger_kernel/transformers/model/mixtral.py +16 -7
  68. liger_kernel/transformers/model/mllama.py +12 -4
  69. liger_kernel/transformers/model/olmo2.py +16 -4
  70. liger_kernel/transformers/model/olmo3.py +142 -0
  71. liger_kernel/transformers/model/output_classes.py +147 -0
  72. liger_kernel/transformers/model/paligemma.py +23 -5
  73. liger_kernel/transformers/model/phi3.py +14 -7
  74. liger_kernel/transformers/model/qwen2.py +16 -3
  75. liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
  76. liger_kernel/transformers/model/qwen2_vl.py +16 -4
  77. liger_kernel/transformers/model/qwen3.py +20 -5
  78. liger_kernel/transformers/model/qwen3_moe.py +19 -5
  79. liger_kernel/transformers/model/qwen3_next.py +17 -5
  80. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  81. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  82. liger_kernel/transformers/model/smollm3.py +15 -6
  83. liger_kernel/transformers/monkey_patch.py +584 -49
  84. liger_kernel/transformers/multi_token_attention.py +1 -1
  85. liger_kernel/transformers/poly_norm.py +1 -1
  86. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  87. liger_kernel/transformers/rms_norm.py +8 -3
  88. liger_kernel/transformers/rope.py +45 -1
  89. liger_kernel/transformers/softmax.py +1 -1
  90. liger_kernel/transformers/sparsemax.py +1 -1
  91. liger_kernel/transformers/swiglu.py +18 -1
  92. liger_kernel/transformers/tiled_mlp.py +125 -0
  93. liger_kernel/transformers/tvd.py +1 -1
  94. liger_kernel/utils.py +54 -0
  95. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/METADATA +14 -4
  96. liger_kernel-0.6.5.dist-info/RECORD +134 -0
  97. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/WHEEL +1 -1
  98. liger_kernel-0.6.3.dist-info/RECORD +0 -111
  99. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/licenses/LICENSE +0 -0
  100. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/licenses/NOTICE +0 -0
  101. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/top_level.txt +0 -0
liger_kernel/ops/fused_add_rms_norm.py CHANGED
@@ -8,9 +8,12 @@ import triton.language as tl
 from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import get_npu_core_count
+from liger_kernel.ops.utils import set_large_grf_mode
 from liger_kernel.ops.utils import torch_to_triton_dtype
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
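The `is_npu_available()` guard keeps the `triton.language.extra.libdevice` import path off Ascend NPU builds, which fall through to `triton.language.math.rsqrt` instead. The helper lives in `liger_kernel/utils.py`; a plausible minimal sketch, assuming the usual `torch_npu` presence probe (the shipped implementation may differ):

import importlib.util

def is_npu_available() -> bool:
    # torch_npu provides Ascend NPU support for PyTorch, so its presence
    # is a reasonable availability signal (assumption, not the shipped code)
    return importlib.util.find_spec("torch_npu") is not None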
@@ -160,23 +163,21 @@ def _fused_add_rms_norm_backward_kernel(
 
     dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
 
-    dY_ptr += row_start * dY_row_stride
-    dX_ptr += row_start * dX_row_stride
-    if has_dS_out:
-        dS_out_ptr += row_start * dS_out_row_stride
-
-    X_ptr += row_start * X_row_stride
-    RSTD_ptr += row_start
-
     W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
     W_row = W_row + offset
 
-    for _ in range(row_start, row_end):
-        dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
-        X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
+    for row_idx in range(row_start, row_end):
+        dy_base = dY_ptr + row_idx * dY_row_stride
+        dx_base = dX_ptr + row_idx * dX_row_stride
+
+        x_base = X_ptr + row_idx * X_row_stride
+        rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
+
+        dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0)
+        X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0)
 
         # Get cached rms
-        rstd_row = tl.load(RSTD_ptr)
+        rstd_row = tl.load(rstd_base)
 
         X_row = X_row.to(tl.float32)
 
@@ -193,11 +194,11 @@ def _fused_add_rms_norm_backward_kernel(
         dX_row = rstd_row * m
 
         if has_dS_out:
-            dS_out_row = tl.load(dS_out_ptr + col_offsets, mask=mask, other=0.0)
+            ds_base = dS_out_ptr + row_idx * dS_out_row_stride
+            dS_out_row = tl.load(ds_base + col_offsets, mask=mask, other=0.0)
             dX_row += (rstd_row) * (
                 -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
             ) + dS_out_row
-            dS_out_ptr += dS_out_row_stride
         else:
             dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
 
@@ -208,12 +209,7 @@ def _fused_add_rms_norm_backward_kernel(
         # here X_row is already in fp32 (see previous if block)
         dW_row += dY_row * (X_row * rstd_row)
 
-        tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
-
-        dY_ptr += dY_row_stride
-        dX_ptr += dX_row_stride
-        X_ptr += X_row_stride
-        RSTD_ptr += RSTD_row_stride
+        tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask)
 
     tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
 
@@ -252,7 +248,7 @@ def fused_add_rms_norm_forward(X, R, W, eps, offset, casting_mode):
     # XPU-specific optimization
     kernel_args = {}
     if X.device.type == "xpu":
-        kernel_args["grf_mode"] = "large"
+        set_large_grf_mode(kernel_args)
 
     # TODO: add _block_fused_add_rms_norm_forward_kernel
     _fused_add_rms_norm_forward_kernel[(n_rows,)](
@@ -293,6 +289,8 @@ def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BL
         sm_count = torch.cuda.get_device_properties(S.device).multi_processor_count
     elif S.device.type == "xpu":
         sm_count = torch.xpu.get_device_properties(S.device).gpu_eu_count
+    elif S.device.type == "npu":
+        sm_count = get_npu_core_count()
 
     # fp32 for numerical stability especially.
     _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
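The `(sm_count, n_cols)` scratch buffer implements a two-stage, atomic-free reduction for the weight gradient: each Triton program accumulates its assigned rows into one scratch row, and the host sums the partial rows afterwards. A minimal eager sketch of the pattern, with stand-in sizes and data:

import math
import torch

n_rows, n_cols, sm_count = 1000, 4096, 108  # illustrative sizes
rows_per_program = math.ceil(n_rows / sm_count)
contrib = torch.randn(n_rows, n_cols)  # stand-in for dY_row * (X_row * rstd_row)
_dW = torch.zeros(sm_count, n_cols, dtype=torch.float32)
for pid in range(sm_count):  # each program owns a contiguous slab of rows
    r0 = pid * rows_per_program
    r1 = min(r0 + rows_per_program, n_rows)
    _dW[pid] = contrib[r0:r1].sum(dim=0)
dW = _dW.sum(dim=0)  # host-side final reduction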
@@ -310,7 +308,7 @@ def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BL
     # XPU-specific optimization
     kernel_args = {}
     if S.device.type == "xpu":
-        kernel_args["grf_mode"] = "large"
+        set_large_grf_mode(kernel_args)
 
     # TODO: add _block_fused_add_rms_norm_backward_kernel
     _fused_add_rms_norm_backward_kernel[grid](
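`set_large_grf_mode` centralizes the XPU register-file hint that used to be written inline at each launch site. Judging from the calls it replaces, a minimal sketch could look like the following (an assumption; the shipped helper in `liger_kernel/ops/utils.py` may additionally gate on Triton or driver support):

def set_large_grf_mode(kernel_args: dict) -> None:
    # Request the large general register file on XPU: fewer register spills
    # for wide rows, at the cost of lower occupancy.
    kernel_args["grf_mode"] = "large"  # mirrors the inline line this replaced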
liger_kernel/ops/fused_linear_cross_entropy.py CHANGED
@@ -6,11 +6,12 @@ from liger_kernel.ops.utils import amp_custom_bwd
 from liger_kernel.ops.utils import amp_custom_fwd
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2
+MAX_FUSED_SIZE = 2048 if infer_device() == "npu" else 65536 // 2
 
 
 def fused_linear_cross_entropy_forward(
@@ -27,8 +28,12 @@ def fused_linear_cross_entropy_forward(
     return_z_loss=False,
     accum_dtype=None,
     use_token_scaling=False,
+    return_token_accuracy=False,
 ):
     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+    )
     device = _input.device
 
     input_requires_grad = _input.requires_grad
@@ -58,9 +63,13 @@ def fused_linear_cross_entropy_forward(
         else:
             grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
             grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
+    else:
+        grad_weight = None
+        grad_bias = None
 
     loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
     z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+    token_accuracy_1d = torch.zeros(BT, dtype=torch.float32, device=device) if return_token_accuracy else None
 
     # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
     target_mask = target != ignore_index
@@ -126,6 +135,7 @@ def fused_linear_cross_entropy_forward(
         # unreduced loss
         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
         z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
+        token_accuracy_1d_slice = token_accuracy_1d[start_idx:end_idx] if return_token_accuracy else None
 
         # ensure _input and target are contiguous
         logits_chunk = logits_chunk.contiguous()
@@ -141,6 +151,10 @@ def fused_linear_cross_entropy_forward(
             loss_ptr=loss_1d_slice,
             z_loss_ptr=z_loss_1d_slice,
             loss_stride=loss_1d_slice.stride(-1),  # always 1
+            token_accuracy_ptr=token_accuracy_1d_slice,
+            token_accuracy_stride=token_accuracy_1d_slice.stride(-1)
+            if return_token_accuracy
+            else 0,  # always 1 if accuracy is enabled
             n_cols=V,
             n_non_ignore=total_n_non_ignore,
             sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
@@ -151,6 +165,7 @@ def fused_linear_cross_entropy_forward(
             reduction=reduction,
             softcap=softcap,
             RETURN_Z_LOSS=return_z_loss,
+            RETURN_TOKEN_ACCURACY=return_token_accuracy,
             HAS_WEIGHT=True if ce_weight is not None else False,
             HAS_SOFTCAPPING=True if softcap is not None else False,
             HAS_GRADIENTS=input_requires_grad,
@@ -167,6 +182,8 @@ def fused_linear_cross_entropy_forward(
         loss_1d[start_idx:end_idx] = loss_1d_slice
         if return_z_loss:
             z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
+        if return_token_accuracy:
+            token_accuracy_1d[start_idx:end_idx] = token_accuracy_1d_slice
         grad_logits_chunk = logits_chunk  # chunk_size x V
 
         # Apply token scaling to gradients if requested
@@ -198,15 +215,18 @@ def fused_linear_cross_entropy_forward(
         # Return per-token losses
         loss = loss_1d
         z_loss = z_loss_1d if return_z_loss else None
+        token_accuracy = token_accuracy_1d if return_token_accuracy else None
     else:
         loss = torch.sum(loss_1d)
         z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        # For accuracy, we compute the mean across all non-ignored tokens
+        token_accuracy = torch.sum(token_accuracy_1d) / total_n_non_ignore if return_token_accuracy else None
 
     # Cast back to original dtype
     grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None
     grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None
 
-    return loss, z_loss, grad_input, grad_weight, grad_bias
+    return loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias
 
 
 def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
@@ -274,6 +294,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         return_z_loss: bool = False,
         accum_dtype=None,
         use_token_scaling: bool = False,
+        return_token_accuracy: bool = False,
     ):
         """
         Fusing the last linear layer with cross-entropy loss
@@ -297,9 +318,10 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
                 When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
                 Default: False.
+            return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
         """
 
-        loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
+        loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
             _input=_input,
             weight=weight,
             target=target,
@@ -313,6 +335,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             return_z_loss=return_z_loss,
             accum_dtype=accum_dtype,
             use_token_scaling=use_token_scaling,
+            return_token_accuracy=return_token_accuracy,
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(
@@ -321,13 +344,16 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             grad_bias.detach() if bias is not None else None,
         )
         ctx.return_z_loss = return_z_loss
-        return loss, z_loss
+        ctx.return_token_accuracy = return_token_accuracy
+        return loss, z_loss, token_accuracy
 
     @staticmethod
     @amp_custom_bwd
-    def backward(ctx, grad_output, grad_output2):
+    def backward(ctx, grad_output, grad_output2, grad_output3):
         if ctx.return_z_loss:
             del grad_output2  # z_loss is only for logging
+        if ctx.return_token_accuracy:
+            del grad_output3  # token_accuracy is only for metrics
         (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
         grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
             grad_output, grad_input, grad_weight, grad_bias
@@ -346,4 +372,5 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,  # use_token_scaling
+            None,  # return_token_accuracy
         )
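For intuition, `return_token_accuracy` has the semantics of this eager reference (a sketch that materializes logits only for illustration; the fused path computes the same argmax-match statistic chunk by chunk without ever holding the full BT x V logits):

import torch

def token_accuracy_reference(_input, weight, target, ignore_index=-100):
    logits = _input @ weight.T  # materialized here purely for reference
    pred = logits.argmax(dim=-1)
    mask = target != ignore_index
    # mean over non-ignored tokens, i.e. sum(...) / total_n_non_ignore
    return (pred[mask] == target[mask]).float().mean()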
liger_kernel/ops/geglu.py CHANGED
@@ -7,8 +7,9 @@ import triton.language as tl
 from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import tanh
@@ -66,8 +67,9 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
     tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
     tanh_result = tanh(tanh_arg)
     geglu_a = 0.5 * a_row * (1 + tanh_result)
+    geglu_a = geglu_a.to(dc_row.dtype).to(tl.float32)
 
-    db_row = dc_row * geglu_a
+    db_row = dc_row.cast(tl.float32) * geglu_a
 
     # Gradient w.r.t. a can be computed with:
     # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
@@ -78,7 +80,7 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
     da_row = dc_row * b_row * (term1 + term2)
 
     tl.store(a + col_offsets, da_row, mask=mask)
-    tl.store(b + col_offsets, db_row, mask=mask)
+    tl.store(b + col_offsets, db_row.to(dc_row.dtype), mask=mask)
 
 
 def geglu_forward(a, b):
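The dtype round-trip `geglu_a.to(dc_row.dtype).to(tl.float32)` quantizes the gate to the storage precision before the fp32 multiply, so `db` matches what a non-fused backward would produce. For reference, the tanh-approximated GeGLU this kernel differentiates, in eager PyTorch (a sketch using the same sqrt(2/pi) and 0.044715 constants as the kernel):

import torch

def geglu_tanh_forward_reference(a, b):
    sqrt_2_over_pi = (2.0 / torch.pi) ** 0.5
    # gelu(a) = 0.5 * a * (1 + tanh(sqrt(2/pi) * (a + 0.044715 * a^3)))
    gelu_a = 0.5 * a * (1.0 + torch.tanh(sqrt_2_over_pi * (a + 0.044715 * a.pow(3))))
    return gelu_a * b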
liger_kernel/ops/group_norm.py CHANGED
@@ -6,8 +6,10 @@ import triton.language as tl
 
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import infer_device
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
@@ -17,7 +19,10 @@ if compare_version("triton", operator.ge, "3.0.0"):
 else:
     from triton.language.math import rsqrt
 
-MAX_FUSED_SIZE = 65536
+if infer_device() == "npu":
+    MAX_FUSED_SIZE = 16384  # 8192
+else:
+    MAX_FUSED_SIZE = 65536
 
 
 @triton.jit
@@ -77,15 +82,14 @@ def _group_norm_forward_kernel(
     for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
         W = tl.load(W_ptr + channel_idx)
         B = tl.load(B_ptr + channel_idx)
-        for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
+        # Calculate channel offset within the group
+        channel_offset = (channel_idx - group_idx * channels_per_group) * hidden_size_per_channel
+        for i in tl.range(0, hidden_size_per_channel, BLOCK_SIZE):
             hidden_size_offsets = i + block_range
             mask = hidden_size_offsets < hidden_size_per_channel
-            X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
+            X = tl.load(X_ptr + channel_offset + hidden_size_offsets, mask=mask, other=m)
             Y = (X - m) * rstd * W + B
-            tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
-
-        X_ptr += hidden_size_per_channel
-        Y_ptr += hidden_size_per_channel
+            tl.store(Y_ptr + channel_offset + hidden_size_offsets, Y, mask=mask)
 
     tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
     tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
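The change replaces in-place `X_ptr`/`Y_ptr` bumps with a per-channel offset computed independently on every iteration, so addressing no longer depends on the loop body executing strictly in order. Worked arithmetic with illustrative sizes:

# channels_per_group = 2, hidden_size_per_channel = 4, group_idx = 2
# channel_idx iterates over {4, 5}:
#   channel_idx 4 -> channel_offset = (4 - 2 * 2) * 4 = 0
#   channel_idx 5 -> channel_offset = (5 - 2 * 2) * 4 = 4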
liger_kernel/ops/grpo_loss.py CHANGED
@@ -128,7 +128,9 @@ def _grpo_loss_fwd_kernel(
     per_token_loss1 = coef_1 * advantage
     per_token_loss2 = coef_2 * advantage
     per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
-    is_clipped = per_token_loss1 < per_token_loss2
+    is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0)
+    is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
+    is_clipped = is_low_clipped | is_high_clipped
 
     if BETA != 0.0:
         REF_LOGP += off_b * L + off_l
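The old indicator merely flagged tokens where the clipped branch won the min, regardless of direction; the new one reports low- and high-side clipping separately, matching the usual PPO/GRPO definition. Eager equivalent of the new logic (a sketch; `coef_1` is the unclipped probability ratio):

import torch

def is_clipped_reference(coef_1, advantage, eps_low, eps_high):
    is_low_clipped = (coef_1 < 1 - eps_low) & (advantage < 0)
    is_high_clipped = (coef_1 > 1 + eps_high) & (advantage > 0)
    return is_low_clipped | is_high_clipped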
liger_kernel/ops/kl_div.py CHANGED
@@ -21,7 +21,12 @@ def get_num_warps(BLOCK_SIZE):
     return num_warps
 
 
-MAX_FUSED_SIZE = 65536 // 4  # 65536 // 4 or 8 works the best
+if infer_device() == "xpu":
+    MAX_FUSED_SIZE = 8192
+elif infer_device() == "npu":
+    MAX_FUSED_SIZE = 8192
+else:
+    MAX_FUSED_SIZE = 65536 // 4  # 65536 // 4 or 8 works the best
 
 REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
 
@@ -116,11 +121,7 @@ def _kldiv_kernel_backward(
 
 def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps):  # [BT, V]
     BT, V = y_pred.shape
-    BLOCK_SIZE = (
-        min(8192, triton.next_power_of_2(V))
-        if infer_device() == "xpu"
-        else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-    )
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
     num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
 
     grid = (BT,)
@@ -159,11 +160,7 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps):  # [BT, V]
 
 def kldiv_backward_triton(target, grad_output, new_grads, log_target):
     BT, V = target.shape
-    BLOCK_SIZE = (
-        min(8192, triton.next_power_of_2(V))
-        if infer_device() == "xpu"
-        else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-    )
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
     num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
 
     grid = (BT,)
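With the XPU special case folded into `MAX_FUSED_SIZE`, block-size selection collapses to one expression. Illustrative arithmetic for a GPT-2-sized vocabulary:

import triton

V = 50257
print(triton.next_power_of_2(V))   # 65536
print(min(65536 // 4, 65536))      # 16384 on the default (cuda) path
print(min(8192, 65536))            # 8192 on xpu/npu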
liger_kernel/ops/layer_norm.py CHANGED
@@ -1,3 +1,4 @@
+import math
 import operator
 
 import torch
@@ -7,8 +8,11 @@ import triton.language as tl
 from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import get_npu_core_count
+from liger_kernel.ops.utils import set_large_grf_mode
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
@@ -85,68 +89,81 @@ def _layer_norm_forward_kernel(
 @triton.jit
 def _layer_norm_backward_kernel(
     X_ptr,  # pointer to input, shape (n_rows, n_cols)
+    stride_x,  # stride of each row in input
     W_ptr,  # pointer to weights, shape (n_cols,)
     Mean_ptr,  # pointer to mean, shape (n_rows,)
+    stride_mean,  # stride of each row in mean
     RSTD_ptr,  # pointer to rstd, shape (n_rows,)
+    stride_rstd,  # stride of each row in rstd
     DX_ptr,  # pointer to input grad, shape (n_rows, n_cols)
+    stride_dx,  # stride of each row in input grad
     DW_ptr,  # pointer to weights grad, shape (n_cols,)
+    stride_dw,  # stride of each row in weights grad
     DB_ptr,  # pointer to bias grad, shape (n_cols,)
+    stride_db,  # stride of each row in bias grad
     DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
-    stride_x,  # stride of each row in input
-    stride_dx,  # stride of each row in input grad
     stride_dy,  # stride of each row in output grad
+    n_rows,
     n_cols,
+    rows_per_program: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
-    dtype: tl.constexpr,
-    atomic_dtype: tl.constexpr,
 ):
     """
     References:
    https://arxiv.org/abs/1607.06450
    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
    """
-    row_idx = tl.program_id(0).to(tl.int64)
+    row_block_id = tl.program_id(0).to(tl.int64)
+    row_start = row_block_id * rows_per_program
+    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
     cols = tl.arange(0, BLOCK_SIZE)
     mask = cols < n_cols
 
+    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
     # Pre-load weights once (same optimization as forward pass)
     w = tl.load(W_ptr + cols, mask=mask, other=0.0)
     w_f32 = w.to(tl.float32)
 
-    # Calculate pointers for this specific row
-    row_X_ptr = X_ptr + row_idx * stride_x
-    row_DX_ptr = DX_ptr + row_idx * stride_dx
-    row_DY_ptr = DY_ptr + row_idx * stride_dy
-    row_Mean_ptr = Mean_ptr + row_idx
-    row_RSTD_ptr = RSTD_ptr + row_idx
-
-    # Load data for this row
-    x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
-    dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
-    mean = tl.load(row_Mean_ptr)
-    rstd = tl.load(row_RSTD_ptr)
-
-    # Convert to fp32 for numerical stability
-    x_f32 = x.to(tl.float32)
-    dy_f32 = dy.to(tl.float32)
-    mean_f32 = mean.to(tl.float32)
-    rstd_f32 = rstd.to(tl.float32)
-
-    # Compute backward pass for this row
-    x_hat = (x_f32 - mean_f32) * rstd_f32
-    wdy = w_f32 * dy_f32
-    c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
-    c2 = tl.sum(wdy, axis=0) / n_cols
-    dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
-
-    # Store input gradient
-    tl.store(row_DX_ptr + cols, dx.to(dtype), mask=mask)
-
-    # Accumulate weight and bias gradients using atomic operations
-    dw = dy_f32 * x_hat
-    db = dy_f32
-    tl.atomic_add(DW_ptr + cols, dw.to(atomic_dtype), mask=mask)
-    tl.atomic_add(DB_ptr + cols, db.to(atomic_dtype), mask=mask)
+    for row_idx in range(row_start, row_end):
+        # Calculate pointers for this specific row
+        row_X_ptr = X_ptr + row_idx * stride_x
+        row_DX_ptr = DX_ptr + row_idx * stride_dx
+        row_DY_ptr = DY_ptr + row_idx * stride_dy
+        row_Mean_ptr = Mean_ptr + row_idx * stride_mean
+        row_RSTD_ptr = RSTD_ptr + row_idx * stride_rstd
+
+        # Load data for this row
+        x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
+        dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
+        mean = tl.load(row_Mean_ptr)
+        rstd = tl.load(row_RSTD_ptr)
+
+        # Convert to fp32 for numerical stability
+        x_f32 = x.to(tl.float32)
+        dy_f32 = dy.to(tl.float32)
+        mean_f32 = mean.to(tl.float32)
+        rstd_f32 = rstd.to(tl.float32)
+
+        # Compute backward pass for this row
+        x_hat = (x_f32 - mean_f32) * rstd_f32
+        wdy = w_f32 * dy_f32
+        c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
+        c2 = tl.sum(wdy, axis=0) / n_cols
+        dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
+
+        # Store input gradient
+        tl.store(row_DX_ptr + cols, dx, mask=mask)
+
+        # Accumulate weight and bias gradients for this thread block's assigned rows
+        dw = dy_f32 * x_hat
+        db = dy_f32
+        dW_row += dw
+        db_row += db
+
+    tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
+    tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)
 
 
 def layer_norm_forward(X, W, B, eps):
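The rewrite trades one-program-per-row with `tl.atomic_add` for a fixed grid of `sm_count` programs, each looping over `rows_per_program` rows and emitting one fp32 partial row of `DW`/`DB`; the host then sums the partials, which is deterministic and needs no bf16 atomic-add workaround. The per-row math is unchanged; in eager PyTorch it reads (a reference sketch mirroring the kernel's fp32 computation):

import torch

def layer_norm_backward_reference(dy, x, w, mean, rstd):
    x_hat = (x.float() - mean[:, None].float()) * rstd[:, None].float()
    wdy = w.float() * dy.float()
    n_cols = x.shape[-1]
    c1 = (x_hat * wdy).sum(-1, keepdim=True) / n_cols
    c2 = wdy.sum(-1, keepdim=True) / n_cols
    dx = (wdy - (x_hat * c1 + c2)) * rstd[:, None].float()
    dw = (dy.float() * x_hat).sum(dim=0)  # reduced over rows
    db = dy.float().sum(dim=0)
    return dx, dw, db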
@@ -183,7 +200,7 @@ def layer_norm_forward(X, W, B, eps):
     # XPU-specific optimization
     kernel_args = {}
     if X.device.type == "xpu":
-        kernel_args["grf_mode"] = "large"
+        set_large_grf_mode(kernel_args)
 
     # Launch kernel with one thread block per row for optimal performance
     grid = (n_rows,)
@@ -228,60 +245,63 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
     dY = dY.view(-1, dim)
     n_rows, n_cols = dY.shape
 
-    # Allocate gradient tensors
-    DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
-    # Use float32 for weight/bias gradients if bfloat16 (due to atomic_add limitation)
-    grad_dtype = torch.float32 if W.dtype == torch.bfloat16 else W.dtype
-    DW = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
-    DB = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
+    sm_count = 1
+    if X.device.type == "cuda":
+        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+    elif X.device.type == "xpu":
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+    elif X.device.type == "npu":
+        sm_count = get_npu_core_count()
+
+    # fp32 for numerical stability especially.
+    _DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+    _DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
 
     # Calculate optimal block size and warp configuration
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
     if n_cols > BLOCK_SIZE:
         raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
+    rows_per_program = math.ceil(n_rows / sm_count)
+    grid = (sm_count,)
 
-    # Determine dtype for triton operations
-    triton_dtype = (
-        tl.float32
-        if X.dtype == torch.float32
-        else tl.bfloat16
-        if X.dtype == torch.bfloat16
-        else tl.float16
-        if X.dtype == torch.float16
-        else tl.float32  # fallback
-    )
-
-    # Use float32 for atomic operations if bfloat16 is not supported
-    atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype
+    # Allocate gradient tensors
+    DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
 
     kernel_args = {"num_warps": num_warps}
     # XPU-specific optimization
     if X.device.type == "xpu":
-        kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
+        kernel_args.update({"num_warps": 32, "num_stages": 4})
+        set_large_grf_mode(kernel_args)
 
     # Launch kernel with one thread block per row for optimal performance
-    grid = (n_rows,)
     _layer_norm_backward_kernel[grid](
         X,
+        X.stride(0),
         W,
         Mean,
+        Mean.stride(0),
        RSTD,
+        RSTD.stride(0),
         DX,
-        DW,
-        DB,
-        dY,
-        X.stride(0),
         DX.stride(0),
+        _DW,
+        _DW.stride(0),
+        _DB,
+        _DB.stride(0),
+        dY,
         dY.stride(0),
+        n_rows,
         n_cols,
+        rows_per_program=rows_per_program,
         BLOCK_SIZE=BLOCK_SIZE,
-        dtype=triton_dtype,
-        atomic_dtype=atomic_dtype,
         **kernel_args,
     )
 
     DX = DX.view(*shape)
-    return DX, DW.to(W.dtype), DB.to(W.dtype)
+    DW = _DW.sum(dim=0).to(W.dtype)
+    DB = _DB.sum(dim=0).to(B.dtype)
+
+    return DX, DW, DB
 
 
 class LigerLayerNormFunction(torch.autograd.Function):
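Grid sizing for the new backward launch, worked through with illustrative numbers (132 SMs is an H100-class figure, used here only as an example):

import math

n_rows, sm_count = 8192, 132
rows_per_program = math.ceil(n_rows / sm_count)  # 63
grid = (sm_count,)  # the grid no longer scales with n_rows
assert sm_count * rows_per_program >= n_rows  # every row is covered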