liger-kernel-nightly 0.6.2.dev20251011154427__py3-none-any.whl → 0.6.4.dev20260107111351__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (97)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +20 -5
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +8 -5
  5. liger_kernel/chunked_loss/jsd_loss.py +39 -11
  6. liger_kernel/ops/__init__.py +141 -0
  7. liger_kernel/ops/backends/README.md +151 -0
  8. liger_kernel/ops/backends/__init__.py +13 -0
  9. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  10. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  11. liger_kernel/ops/backends/_ascend/ops/__init__.py +43 -0
  12. liger_kernel/ops/backends/_ascend/ops/geglu.py +244 -0
  13. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  14. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  15. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  16. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  17. liger_kernel/ops/backends/registry.py +61 -0
  18. liger_kernel/ops/cross_entropy.py +75 -12
  19. liger_kernel/ops/dyt.py +5 -2
  20. liger_kernel/ops/fused_add_rms_norm.py +5 -1
  21. liger_kernel/ops/fused_linear_cross_entropy.py +45 -14
  22. liger_kernel/ops/geglu.py +5 -3
  23. liger_kernel/ops/group_norm.py +2 -1
  24. liger_kernel/ops/grpo_loss.py +3 -1
  25. liger_kernel/ops/layer_norm.py +86 -66
  26. liger_kernel/ops/poly_norm.py +390 -0
  27. liger_kernel/ops/rms_norm.py +131 -49
  28. liger_kernel/ops/tiled_mlp.py +136 -0
  29. liger_kernel/ops/utils.py +14 -0
  30. liger_kernel/transformers/__init__.py +30 -0
  31. liger_kernel/transformers/auto_model.py +21 -0
  32. liger_kernel/transformers/cross_entropy.py +9 -4
  33. liger_kernel/transformers/dyt.py +1 -1
  34. liger_kernel/transformers/experimental/embedding.py +1 -1
  35. liger_kernel/transformers/functional.py +48 -25
  36. liger_kernel/transformers/fused_add_rms_norm.py +1 -1
  37. liger_kernel/transformers/fused_linear_cross_entropy.py +9 -4
  38. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  39. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  40. liger_kernel/transformers/geglu.py +1 -1
  41. liger_kernel/transformers/group_norm.py +1 -1
  42. liger_kernel/transformers/grpo_loss.py +57 -2
  43. liger_kernel/transformers/jsd.py +1 -1
  44. liger_kernel/transformers/kl_div.py +1 -1
  45. liger_kernel/transformers/layer_norm.py +1 -1
  46. liger_kernel/transformers/llama4_rope.py +1 -1
  47. liger_kernel/transformers/model/falcon_h1.py +19 -5
  48. liger_kernel/transformers/model/gemma.py +17 -6
  49. liger_kernel/transformers/model/gemma2.py +14 -5
  50. liger_kernel/transformers/model/gemma3.py +26 -12
  51. liger_kernel/transformers/model/glm4.py +16 -4
  52. liger_kernel/transformers/model/glm4v.py +16 -4
  53. liger_kernel/transformers/model/glm4v_moe.py +23 -4
  54. liger_kernel/transformers/model/gpt_oss.py +211 -0
  55. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  56. liger_kernel/transformers/model/internvl.py +12 -5
  57. liger_kernel/transformers/model/llama.py +14 -5
  58. liger_kernel/transformers/model/llama4.py +16 -4
  59. liger_kernel/transformers/model/llava.py +12 -4
  60. liger_kernel/transformers/model/loss_utils.py +31 -3
  61. liger_kernel/transformers/model/mistral.py +15 -6
  62. liger_kernel/transformers/model/mixtral.py +16 -7
  63. liger_kernel/transformers/model/mllama.py +12 -4
  64. liger_kernel/transformers/model/olmo2.py +16 -4
  65. liger_kernel/transformers/model/olmo3.py +142 -0
  66. liger_kernel/transformers/model/output_classes.py +147 -0
  67. liger_kernel/transformers/model/paligemma.py +23 -5
  68. liger_kernel/transformers/model/phi3.py +14 -7
  69. liger_kernel/transformers/model/qwen2.py +16 -3
  70. liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
  71. liger_kernel/transformers/model/qwen2_vl.py +16 -4
  72. liger_kernel/transformers/model/qwen3.py +20 -5
  73. liger_kernel/transformers/model/qwen3_moe.py +19 -5
  74. liger_kernel/transformers/model/qwen3_next.py +146 -0
  75. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  76. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  77. liger_kernel/transformers/model/smollm3.py +15 -6
  78. liger_kernel/transformers/model/smolvlm.py +158 -0
  79. liger_kernel/transformers/monkey_patch.py +702 -48
  80. liger_kernel/transformers/multi_token_attention.py +1 -1
  81. liger_kernel/transformers/poly_norm.py +42 -0
  82. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  83. liger_kernel/transformers/rms_norm.py +15 -3
  84. liger_kernel/transformers/rope.py +45 -1
  85. liger_kernel/transformers/softmax.py +1 -1
  86. liger_kernel/transformers/sparsemax.py +1 -1
  87. liger_kernel/transformers/swiglu.py +18 -1
  88. liger_kernel/transformers/tiled_mlp.py +133 -0
  89. liger_kernel/transformers/tvd.py +1 -1
  90. liger_kernel/utils.py +52 -0
  91. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/METADATA +12 -3
  92. liger_kernel_nightly-0.6.4.dev20260107111351.dist-info/RECORD +130 -0
  93. liger_kernel_nightly-0.6.2.dev20251011154427.dist-info/RECORD +0 -107
  94. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/LICENSE +0 -0
  95. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/NOTICE +0 -0
  96. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/WHEEL +0 -0
  97. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/top_level.txt +0 -0
liger_kernel/ops/fused_add_rms_norm.py CHANGED
@@ -9,8 +9,10 @@ from liger_kernel.ops.utils import calculate_settings
  from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import ensure_contiguous
  from liger_kernel.ops.utils import torch_to_triton_dtype
+ from liger_kernel.utils import get_npu_multi_processor_count
+ from liger_kernel.utils import is_npu_available

- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
      try:
          # typical import path with dispatch available
          from triton.language.extra.libdevice import rsqrt
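The gate above depends on two new helpers in `liger_kernel/utils.py` (+52 lines in this release) whose bodies are not shown in this diff. A minimal sketch of what such detection helpers typically look like, assuming Ascend support is signalled by an importable `torch_npu` plugin (the shipped implementation may differ):

```python
import importlib.util

import torch


def is_npu_available() -> bool:
    # Ascend NPUs are only usable when the torch_npu plugin is importable.
    return importlib.util.find_spec("torch_npu") is not None


def infer_device() -> str:
    # Ordered probe: CUDA, then Intel XPU, then Ascend NPU, else CPU.
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if is_npu_available():
        return "npu"
    return "cpu"
```

Skipping the `triton.language.extra.libdevice` import on NPU matters because those intrinsics target CUDA/HIP builds of Triton; with the gate false, the module falls back to the alternative import path this file already provides.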
@@ -293,6 +295,8 @@ def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BL
          sm_count = torch.cuda.get_device_properties(S.device).multi_processor_count
      elif S.device.type == "xpu":
          sm_count = torch.xpu.get_device_properties(S.device).gpu_eu_count
+     elif S.device.type == "npu":
+         sm_count = get_npu_multi_processor_count()

      # fp32 for numerical stability especially.
      _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
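The `sm_count` queried here sizes a split reduction: each Triton program accumulates a partial fp32 gradient row, and the host sums the partials instead of relying on atomic adds (the same pattern the layer_norm rewrite below adopts). A self-contained sketch with stand-in values:

```python
import torch

sm_count, n_cols = 4, 8  # stand-ins for the device query and hidden size
W = torch.randn(n_cols, dtype=torch.bfloat16)

# One fp32 partial row per program, matching _dW's (sm_count, n_cols) shape.
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32)
for pid in range(sm_count):          # in reality each program writes its row on-device
    _dW[pid] = torch.randn(n_cols)   # placeholder for the kernel's accumulation

dW = _dW.sum(dim=0).to(W.dtype)      # final host-side reduction, cast back once
```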
liger_kernel/ops/fused_linear_cross_entropy.py CHANGED
@@ -6,11 +6,12 @@ from liger_kernel.ops.utils import amp_custom_bwd
  from liger_kernel.ops.utils import amp_custom_fwd
  from liger_kernel.ops.utils import element_mul_kernel
  from liger_kernel.ops.utils import is_hip
+ from liger_kernel.utils import infer_device

  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
  # The optimal maximum block size depends on your hardware, your kernel, and your dtype
- MAX_FUSED_SIZE = 65536 // 2
+ MAX_FUSED_SIZE = 2048 if infer_device() == "npu" else 65536 // 2


  def fused_linear_cross_entropy_forward(
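`MAX_FUSED_SIZE` caps the Triton block width used to sweep the vocab dimension, so the lower NPU value trades more loop iterations for smaller per-program tiles. The block-size derivation itself is not part of this hunk; an illustrative sketch of what the cap controls:

```python
import triton

MAX_FUSED_SIZE = 2048  # NPU cap from the hunk above; 65536 // 2 elsewhere


def block_size_for_vocab(vocab_size: int) -> int:
    # Triton blocks are power-of-2 sized and bounded by MAX_FUSED_SIZE.
    return min(MAX_FUSED_SIZE, triton.next_power_of_2(vocab_size))


print(block_size_for_vocab(151_936))  # 2048 on NPU vs. 32768 on other devices
```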
@@ -27,10 +28,16 @@ def fused_linear_cross_entropy_forward(
      return_z_loss=False,
      accum_dtype=None,
      use_token_scaling=False,
+     return_token_accuracy=False,
  ):
      assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+     assert isinstance(return_token_accuracy, bool), (
+         f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+     )
      device = _input.device

+     input_requires_grad = _input.requires_grad
+
      # inputs have shape: BT x H
      # materialized activations will have shape: BT x V
      # the increase in memory = BT x V
@@ -49,15 +56,20 @@ def fused_linear_cross_entropy_forward(
      grad_input = torch.zeros_like(_input, device=device)

      # we use fp32 for loss and gradients accumulator
-     if accum_dtype is None:
-         grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
-         grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
+     if input_requires_grad:
+         if accum_dtype is None:
+             grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
+             grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
+         else:
+             grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
+             grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
      else:
-         grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
-         grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
+         grad_weight = None
+         grad_bias = None

      loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
      z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+     token_accuracy_1d = torch.zeros(BT, dtype=torch.float32, device=device) if return_token_accuracy else None

      # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
      target_mask = target != ignore_index
@@ -123,6 +135,7 @@ def fused_linear_cross_entropy_forward(
          # unreduced loss
          loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
          z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
+         token_accuracy_1d_slice = token_accuracy_1d[start_idx:end_idx] if return_token_accuracy else None

          # ensure _input and target are contiguous
          logits_chunk = logits_chunk.contiguous()
@@ -138,6 +151,10 @@ def fused_linear_cross_entropy_forward(
              loss_ptr=loss_1d_slice,
              z_loss_ptr=z_loss_1d_slice,
              loss_stride=loss_1d_slice.stride(-1),  # always 1
+             token_accuracy_ptr=token_accuracy_1d_slice,
+             token_accuracy_stride=token_accuracy_1d_slice.stride(-1)
+             if return_token_accuracy
+             else 0,  # always 1 if accuracy is enabled
              n_cols=V,
              n_non_ignore=total_n_non_ignore,
              sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
@@ -148,9 +165,10 @@ def fused_linear_cross_entropy_forward(
              reduction=reduction,
              softcap=softcap,
              RETURN_Z_LOSS=return_z_loss,
+             RETURN_TOKEN_ACCURACY=return_token_accuracy,
              HAS_WEIGHT=True if ce_weight is not None else False,
              HAS_SOFTCAPPING=True if softcap is not None else False,
-             HAS_GRADIENTS=_input.requires_grad,
+             HAS_GRADIENTS=input_requires_grad,
              BLOCK_SIZE=BLOCK_SIZE,
              num_warps=32 if not is_hip() else 16,
          )
@@ -164,6 +182,8 @@ def fused_linear_cross_entropy_forward(
          loss_1d[start_idx:end_idx] = loss_1d_slice
          if return_z_loss:
              z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
+         if return_token_accuracy:
+             token_accuracy_1d[start_idx:end_idx] = token_accuracy_1d_slice
          grad_logits_chunk = logits_chunk  # chunk_size x V

          # Apply token scaling to gradients if requested
@@ -172,12 +192,13 @@ def fused_linear_cross_entropy_forward(
              scaling_factors_expanded = scaling_factors.unsqueeze(-1)  # chunk_size x 1
              grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded

-         grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
+         if input_requires_grad:
+             grad_input[start_idx:end_idx] = grad_logits_chunk @ weight

-         if grad_weight is not None and _input.requires_grad:
+         if grad_weight is not None and input_requires_grad:
              grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()

-         if bias is not None and _input.requires_grad:
+         if bias is not None and input_requires_grad:
              torch.add(
                  input=grad_bias,
                  other=grad_logits_chunk.sum(dim=0),
@@ -194,15 +215,18 @@ def fused_linear_cross_entropy_forward(
          # Return per-token losses
          loss = loss_1d
          z_loss = z_loss_1d if return_z_loss else None
+         token_accuracy = token_accuracy_1d if return_token_accuracy else None
      else:
          loss = torch.sum(loss_1d)
          z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+         # For accuracy, we compute the mean across all non-ignored tokens
+         token_accuracy = torch.sum(token_accuracy_1d) / total_n_non_ignore if return_token_accuracy else None

      # Cast back to original dtype
      grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None
      grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None

-     return loss, z_loss, grad_input, grad_weight, grad_bias
+     return loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias


  def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
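The semantics of the new output can be read off the reduction above: `token_accuracy_1d` holds a per-token hit indicator for non-ignored targets, and the reduced value is its sum divided by `total_n_non_ignore`. A plain-PyTorch sketch of the same quantity for checking (the kernel computes it chunk by chunk, never materializing the full BT x V logits):

```python
import torch


def token_accuracy_reference(logits, target, ignore_index=-100):
    # Per-token correctness, masked the same way the kernel masks ignore_index.
    mask = target != ignore_index
    hits = ((logits.argmax(dim=-1) == target) & mask).float()  # ~ token_accuracy_1d
    mean_acc = hits.sum() / mask.sum()  # ~ torch.sum(...) / total_n_non_ignore
    return hits, mean_acc


logits = torch.randn(8, 1000)
target = torch.randint(0, 1000, (8,))
target[0] = -100  # an ignored position counts toward neither sum
per_token, acc = token_accuracy_reference(logits, target)
```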
@@ -270,6 +294,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
          return_z_loss: bool = False,
          accum_dtype=None,
          use_token_scaling: bool = False,
+         return_token_accuracy: bool = False,
      ):
          """
          Fusing the last linear layer with cross-entropy loss
@@ -293,9 +318,10 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
          use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
              When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
              Default: False.
+         return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
          """

-         loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
+         loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
              _input=_input,
              weight=weight,
              target=target,
@@ -309,6 +335,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
              return_z_loss=return_z_loss,
              accum_dtype=accum_dtype,
              use_token_scaling=use_token_scaling,
+             return_token_accuracy=return_token_accuracy,
          )
          # downcast to dtype and store for backward
          ctx.save_for_backward(
@@ -317,13 +344,16 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
              grad_bias.detach() if bias is not None else None,
          )
          ctx.return_z_loss = return_z_loss
-         return loss, z_loss
+         ctx.return_token_accuracy = return_token_accuracy
+         return loss, z_loss, token_accuracy

      @staticmethod
      @amp_custom_bwd
-     def backward(ctx, grad_output, grad_output2):
+     def backward(ctx, grad_output, grad_output2, grad_output3):
          if ctx.return_z_loss:
              del grad_output2  # z_loss is only for logging
+         if ctx.return_token_accuracy:
+             del grad_output3  # token_accuracy is only for metrics
          (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
          grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
              grad_output, grad_input, grad_weight, grad_bias
@@ -342,4 +372,5 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
              None,
              None,
              None,  # use_token_scaling
+             None,  # return_token_accuracy
          )
liger_kernel/ops/geglu.py CHANGED
@@ -7,8 +7,9 @@ import triton.language as tl
  from liger_kernel.ops.utils import calculate_settings
  from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.utils import is_npu_available

- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
      try:
          # typical import path with dispatch available
          from triton.language.extra.libdevice import tanh
@@ -66,8 +67,9 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
      tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
      tanh_result = tanh(tanh_arg)
      geglu_a = 0.5 * a_row * (1 + tanh_result)
+     geglu_a = geglu_a.to(dc_row.dtype).to(tl.float32)

-     db_row = dc_row * geglu_a
+     db_row = dc_row.cast(tl.float32) * geglu_a

      # Gradient w.r.t. a can be computed with:
      # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
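The two new casts keep the `db` product in fp32 until the single downcast at the store (next hunk). A plain-PyTorch restatement of this kernel's backward math, assembled from the formulas in the comments above (a checking sketch, not the Triton code):

```python
import math

import torch


def geglu_tanh_backward_ref(dc, a, b):
    # Forward is c = b * geglu(a) with the tanh approximation of GELU.
    k = math.sqrt(2.0 / math.pi)
    t = torch.tanh(k * (a + 0.044715 * a**3))
    geglu_a = 0.5 * a * (1.0 + t)
    db = dc * geglu_a  # kept in fp32, as the new casts enforce
    # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
    term1 = 0.5 * (1.0 + t)
    term2 = 0.5 * a * (1.0 - t * t) * (k * (1.0 + 3.0 * 0.044715 * a * a))
    da = dc * b * (term1 + term2)
    return da, db
```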
@@ -78,7 +80,7 @@
      da_row = dc_row * b_row * (term1 + term2)

      tl.store(a + col_offsets, da_row, mask=mask)
-     tl.store(b + col_offsets, db_row, mask=mask)
+     tl.store(b + col_offsets, db_row.to(dc_row.dtype), mask=mask)


  def geglu_forward(a, b):
liger_kernel/ops/group_norm.py CHANGED
@@ -6,8 +6,9 @@ import triton.language as tl

  from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.utils import is_npu_available

- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
      try:
          # typical import path with dispatch available
          from triton.language.extra.libdevice import rsqrt
liger_kernel/ops/grpo_loss.py CHANGED
@@ -128,7 +128,9 @@
      per_token_loss1 = coef_1 * advantage
      per_token_loss2 = coef_2 * advantage
      per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
-     is_clipped = per_token_loss1 < per_token_loss2
+     is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0)
+     is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
+     is_clipped = is_low_clipped | is_high_clipped

      if BETA != 0.0:
          REF_LOGP += off_b * L + off_l
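The old proxy compared the two surrogate terms, which does not correspond to the clip range actually binding; the new definition follows the standard PPO bookkeeping: the ratio only clips below `1 - eps_low` under negative advantage, or above `1 + eps_high` under positive advantage. A host-side restatement with a tiny example (the `eps` defaults are stand-ins for the kernel's `EPS_LOW`/`EPS_HIGH` constexprs):

```python
import torch


def grpo_is_clipped(coef_1, advantage, eps_low=0.2, eps_high=0.2):
    is_low_clipped = (coef_1 < 1 - eps_low) & (advantage < 0)
    is_high_clipped = (coef_1 > 1 + eps_high) & (advantage > 0)
    return is_low_clipped | is_high_clipped


coef = torch.tensor([0.5, 0.5, 1.5, 1.5])  # importance ratios
adv = torch.tensor([-1.0, 1.0, -1.0, 1.0])
print(grpo_is_clipped(coef, adv))  # tensor([ True, False, False,  True])
```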
liger_kernel/ops/layer_norm.py CHANGED
@@ -1,3 +1,4 @@
+ import math
  import operator

  import torch
@@ -7,8 +8,9 @@ import triton.language as tl
  from liger_kernel.ops.utils import calculate_settings
  from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.utils import is_npu_available

- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
      try:
          # typical import path with dispatch available
          from triton.language.extra.libdevice import rsqrt
@@ -85,68 +87,87 @@
  @triton.jit
  def _layer_norm_backward_kernel(
      X_ptr,  # pointer to input, shape (n_rows, n_cols)
+     stride_x,  # stride of each row in input
      W_ptr,  # pointer to weights, shape (n_cols,)
      Mean_ptr,  # pointer to mean, shape (n_rows,)
+     stride_mean,  # stride of each row in mean
      RSTD_ptr,  # pointer to rstd, shape (n_rows,)
+     stride_rstd,  # stride of each row in rstd
      DX_ptr,  # pointer to input grad, shape (n_rows, n_cols)
+     stride_dx,  # stride of each row in input grad
      DW_ptr,  # pointer to weights grad, shape (n_cols,)
+     stride_dw,  # stride of each row in weights grad
      DB_ptr,  # pointer to bias grad, shape (n_cols,)
+     stride_db,  # stride of each row in bias grad
      DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
-     stride_x,  # stride of each row in input
-     stride_dx,  # stride of each row in input grad
      stride_dy,  # stride of each row in output grad
+     n_rows,
      n_cols,
+     rows_per_program: tl.constexpr,
      BLOCK_SIZE: tl.constexpr,
-     dtype: tl.constexpr,
-     atomic_dtype: tl.constexpr,
  ):
      """
      References:
      https://arxiv.org/abs/1607.06450
      https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
      """
-     row_idx = tl.program_id(0).to(tl.int64)
+     row_block_id = tl.program_id(0).to(tl.int64)
+     row_start = row_block_id * rows_per_program
+     row_end = min((row_block_id + 1) * rows_per_program, n_rows)
      cols = tl.arange(0, BLOCK_SIZE)
      mask = cols < n_cols

+     dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+     db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
      # Pre-load weights once (same optimization as forward pass)
      w = tl.load(W_ptr + cols, mask=mask, other=0.0)
      w_f32 = w.to(tl.float32)

      # Calculate pointers for this specific row
-     row_X_ptr = X_ptr + row_idx * stride_x
-     row_DX_ptr = DX_ptr + row_idx * stride_dx
-     row_DY_ptr = DY_ptr + row_idx * stride_dy
-     row_Mean_ptr = Mean_ptr + row_idx
-     row_RSTD_ptr = RSTD_ptr + row_idx
-
-     # Load data for this row
-     x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
-     dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
-     mean = tl.load(row_Mean_ptr)
-     rstd = tl.load(row_RSTD_ptr)
-
-     # Convert to fp32 for numerical stability
-     x_f32 = x.to(tl.float32)
-     dy_f32 = dy.to(tl.float32)
-     mean_f32 = mean.to(tl.float32)
-     rstd_f32 = rstd.to(tl.float32)
-
-     # Compute backward pass for this row
-     x_hat = (x_f32 - mean_f32) * rstd_f32
-     wdy = w_f32 * dy_f32
-     c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
-     c2 = tl.sum(wdy, axis=0) / n_cols
-     dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
-
-     # Store input gradient
-     tl.store(row_DX_ptr + cols, dx.to(dtype), mask=mask)
-
-     # Accumulate weight and bias gradients using atomic operations
-     dw = dy_f32 * x_hat
-     db = dy_f32
-     tl.atomic_add(DW_ptr + cols, dw.to(atomic_dtype), mask=mask)
-     tl.atomic_add(DB_ptr + cols, db.to(atomic_dtype), mask=mask)
+     row_X_ptr = X_ptr + row_start * stride_x
+     row_DX_ptr = DX_ptr + row_start * stride_dx
+     row_DY_ptr = DY_ptr + row_start * stride_dy
+     row_Mean_ptr = Mean_ptr + row_start
+     row_RSTD_ptr = RSTD_ptr + row_start
+
+     for _ in range(row_start, row_end):
+         # Load data for this row
+         x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
+         dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
+         mean = tl.load(row_Mean_ptr)
+         rstd = tl.load(row_RSTD_ptr)
+
+         # Convert to fp32 for numerical stability
+         x_f32 = x.to(tl.float32)
+         dy_f32 = dy.to(tl.float32)
+         mean_f32 = mean.to(tl.float32)
+         rstd_f32 = rstd.to(tl.float32)
+
+         # Compute backward pass for this row
+         x_hat = (x_f32 - mean_f32) * rstd_f32
+         wdy = w_f32 * dy_f32
+         c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
+         c2 = tl.sum(wdy, axis=0) / n_cols
+         dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
+
+         # Store input gradient
+         tl.store(row_DX_ptr + cols, dx, mask=mask)
+
+         # Accumulate weight and bias gradients for this thread block's assigned rows
+         dw = dy_f32 * x_hat
+         db = dy_f32
+         dW_row += dw
+         db_row += db
+
+         row_X_ptr += stride_x
+         row_DX_ptr += stride_dx
+         row_DY_ptr += stride_dy
+         row_Mean_ptr += stride_mean
+         row_RSTD_ptr += stride_rstd
+
+     tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
+     tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)


  def layer_norm_forward(X, W, B, eps):
@@ -228,31 +249,25 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
      dY = dY.view(-1, dim)
      n_rows, n_cols = dY.shape

-     # Allocate gradient tensors
-     DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
-     # Use float32 for weight/bias gradients if bfloat16 (due to atomic_add limitation)
-     grad_dtype = torch.float32 if W.dtype == torch.bfloat16 else W.dtype
-     DW = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
-     DB = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
+     sm_count = 1
+     if X.device.type == "cuda":
+         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+     elif X.device.type == "xpu":
+         sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+
+     # fp32 for numerical stability especially.
+     _DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+     _DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)

      # Calculate optimal block size and warp configuration
      BLOCK_SIZE, num_warps = calculate_settings(n_cols)
      if n_cols > BLOCK_SIZE:
          raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
+     rows_per_program = math.ceil(n_rows / sm_count)
+     grid = (sm_count,)

-     # Determine dtype for triton operations
-     triton_dtype = (
-         tl.float32
-         if X.dtype == torch.float32
-         else tl.bfloat16
-         if X.dtype == torch.bfloat16
-         else tl.float16
-         if X.dtype == torch.float16
-         else tl.float32  # fallback
-     )
-
-     # Use float32 for atomic operations if bfloat16 is not supported
-     atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype
+     # Allocate gradient tensors
+     DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)

      kernel_args = {"num_warps": num_warps}
      # XPU-specific optimization
@@ -260,28 +275,33 @@
          kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})

      # Launch kernel with one thread block per row for optimal performance
-     grid = (n_rows,)
      _layer_norm_backward_kernel[grid](
          X,
+         X.stride(0),
          W,
          Mean,
+         Mean.stride(0),
          RSTD,
+         RSTD.stride(0),
          DX,
-         DW,
-         DB,
-         dY,
-         X.stride(0),
          DX.stride(0),
+         _DW,
+         _DW.stride(0),
+         _DB,
+         _DB.stride(0),
+         dY,
          dY.stride(0),
+         n_rows,
          n_cols,
+         rows_per_program=rows_per_program,
          BLOCK_SIZE=BLOCK_SIZE,
-         dtype=triton_dtype,
-         atomic_dtype=atomic_dtype,
          **kernel_args,
      )

      DX = DX.view(*shape)
-     return DX, DW.to(W.dtype), DB.to(W.dtype)
+     DW = _DW.sum(dim=0).to(W.dtype)
+     DB = _DB.sum(dim=0).to(B.dtype)
+     return DX, DW, DB


  class LigerLayerNormFunction(torch.autograd.Function):
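The per-row math of the rewritten kernel is unchanged; what moved is the dW/dB reduction, from atomics into per-program partial buffers summed on the host. A plain-PyTorch restatement of that math (a checking sketch, not the kernel), verifiable against autograd:

```python
import torch
import torch.nn.functional as F


def layer_norm_backward_ref(dY, X, W, mean, rstd):
    # Mirrors the kernel: x_hat, wdy, c1, c2, dx per row; dW/dB reduced over rows.
    x_hat = (X - mean[:, None]) * rstd[:, None]
    wdy = W * dY
    c1 = (x_hat * wdy).mean(dim=1, keepdim=True)
    c2 = wdy.mean(dim=1, keepdim=True)
    DX = (wdy - (x_hat * c1 + c2)) * rstd[:, None]
    DW = (dY * x_hat).sum(dim=0)  # what the (sm_count, n_cols) partials sum to
    DB = dY.sum(dim=0)
    return DX, DW, DB


X = torch.randn(4, 16, dtype=torch.float64, requires_grad=True)
W = torch.randn(16, dtype=torch.float64, requires_grad=True)
B = torch.randn(16, dtype=torch.float64, requires_grad=True)
Y = F.layer_norm(X, (16,), W, B, eps=1e-6)
dY = torch.randn_like(Y)
Y.backward(dY)

mean = X.detach().mean(dim=1)
rstd = (X.detach().var(dim=1, unbiased=False) + 1e-6).rsqrt()
DX, DW, DB = layer_norm_backward_ref(dY, X.detach(), W.detach(), mean, rstd)
assert torch.allclose(DX, X.grad) and torch.allclose(DW, W.grad) and torch.allclose(DB, B.grad)
```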