liger-kernel-nightly 0.6.2.dev20250919191028__py3-none-any.whl → 0.6.4.dev20251202054858__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (67):
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +13 -4
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +8 -5
  5. liger_kernel/chunked_loss/jsd_loss.py +18 -5
  6. liger_kernel/ops/cross_entropy.py +120 -63
  7. liger_kernel/ops/dyt.py +5 -2
  8. liger_kernel/ops/fused_add_rms_norm.py +5 -1
  9. liger_kernel/ops/fused_linear_cross_entropy.py +43 -12
  10. liger_kernel/ops/geglu.py +2 -1
  11. liger_kernel/ops/group_norm.py +2 -1
  12. liger_kernel/ops/grpo_loss.py +3 -1
  13. liger_kernel/ops/layer_norm.py +88 -70
  14. liger_kernel/ops/poly_norm.py +390 -0
  15. liger_kernel/ops/rms_norm.py +7 -2
  16. liger_kernel/ops/tiled_mlp.py +136 -0
  17. liger_kernel/ops/utils.py +2 -0
  18. liger_kernel/transformers/__init__.py +33 -0
  19. liger_kernel/transformers/cross_entropy.py +8 -3
  20. liger_kernel/transformers/functional.py +29 -6
  21. liger_kernel/transformers/fused_linear_cross_entropy.py +8 -3
  22. liger_kernel/transformers/grpo_loss.py +56 -1
  23. liger_kernel/transformers/model/falcon_h1.py +122 -0
  24. liger_kernel/transformers/model/gemma.py +19 -7
  25. liger_kernel/transformers/model/gemma2.py +22 -7
  26. liger_kernel/transformers/model/gemma3.py +52 -14
  27. liger_kernel/transformers/model/glm4.py +18 -5
  28. liger_kernel/transformers/model/glm4v.py +18 -5
  29. liger_kernel/transformers/model/glm4v_moe.py +25 -5
  30. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  31. liger_kernel/transformers/model/internvl.py +157 -0
  32. liger_kernel/transformers/model/llama.py +16 -6
  33. liger_kernel/transformers/model/llama4.py +18 -5
  34. liger_kernel/transformers/model/llava.py +18 -6
  35. liger_kernel/transformers/model/loss_utils.py +31 -3
  36. liger_kernel/transformers/model/mistral.py +17 -7
  37. liger_kernel/transformers/model/mixtral.py +24 -9
  38. liger_kernel/transformers/model/mllama.py +14 -5
  39. liger_kernel/transformers/model/olmo2.py +18 -5
  40. liger_kernel/transformers/model/olmo3.py +142 -0
  41. liger_kernel/transformers/model/output_classes.py +147 -0
  42. liger_kernel/transformers/model/paligemma.py +41 -5
  43. liger_kernel/transformers/model/phi3.py +16 -8
  44. liger_kernel/transformers/model/qwen2.py +18 -4
  45. liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
  46. liger_kernel/transformers/model/qwen2_vl.py +24 -7
  47. liger_kernel/transformers/model/qwen3.py +22 -6
  48. liger_kernel/transformers/model/qwen3_moe.py +27 -7
  49. liger_kernel/transformers/model/qwen3_next.py +146 -0
  50. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  51. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  52. liger_kernel/transformers/model/smollm3.py +17 -7
  53. liger_kernel/transformers/model/smolvlm.py +158 -0
  54. liger_kernel/transformers/monkey_patch.py +729 -4
  55. liger_kernel/transformers/poly_norm.py +42 -0
  56. liger_kernel/transformers/rms_norm.py +7 -0
  57. liger_kernel/transformers/rope.py +43 -0
  58. liger_kernel/transformers/swiglu.py +17 -0
  59. liger_kernel/transformers/tiled_mlp.py +133 -0
  60. liger_kernel/utils.py +25 -0
  61. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/METADATA +13 -6
  62. liger_kernel_nightly-0.6.4.dev20251202054858.dist-info/RECORD +118 -0
  63. liger_kernel_nightly-0.6.2.dev20250919191028.dist-info/RECORD +0 -105
  64. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/LICENSE +0 -0
  65. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/NOTICE +0 -0
  66. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/WHEEL +0 -0
  67. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/top_level.txt +0 -0
liger_kernel/ops/layer_norm.py

@@ -1,3 +1,4 @@
+import math
 import operator
 
 import torch
@@ -7,8 +8,9 @@ import triton.language as tl
 from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
@@ -63,12 +65,11 @@ def _layer_norm_forward_kernel(
     X_f32 = X_row.to(tl.float32)
 
     # Compute statistics in fp32 for numerical stability
-    n_cols_f32 = n_cols.to(tl.float32)
-    mean = tl.sum(X_f32, axis=0) / n_cols_f32
+    mean = tl.sum(X_f32, axis=0) / n_cols
     X_centered = X_f32 - mean
     # Apply mask to variance calculation to exclude contributions from masked elements
     X_centered_masked = tl.where(mask, X_centered, 0.0)
-    var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols_f32
+    var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols
     rstd = rsqrt(var + eps)
 
     # Store statistics (convert back to original dtype only once)
@@ -86,69 +87,87 @@ def _layer_norm_forward_kernel(
 @triton.jit
 def _layer_norm_backward_kernel(
     X_ptr,  # pointer to input, shape (n_rows, n_cols)
+    stride_x,  # stride of each row in input
     W_ptr,  # pointer to weights, shape (n_cols,)
     Mean_ptr,  # pointer to mean, shape (n_rows,)
+    stride_mean,  # stride of each row in mean
     RSTD_ptr,  # pointer to rstd, shape (n_rows,)
+    stride_rstd,  # stride of each row in rstd
     DX_ptr,  # pointer to input grad, shape (n_rows, n_cols)
+    stride_dx,  # stride of each row in input grad
     DW_ptr,  # pointer to weights grad, shape (n_cols,)
+    stride_dw,  # stride of each row in weights grad
     DB_ptr,  # pointer to bias grad, shape (n_cols,)
+    stride_db,  # stride of each row in bias grad
     DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
-    stride_x,  # stride of each row in input
-    stride_dx,  # stride of each row in input grad
     stride_dy,  # stride of each row in output grad
+    n_rows,
     n_cols,
+    rows_per_program: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
-    dtype: tl.constexpr,
-    atomic_dtype: tl.constexpr,
 ):
     """
     References:
     https://arxiv.org/abs/1607.06450
     https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
     """
-    row_idx = tl.program_id(0).to(tl.int64)
+    row_block_id = tl.program_id(0).to(tl.int64)
+    row_start = row_block_id * rows_per_program
+    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
     cols = tl.arange(0, BLOCK_SIZE)
     mask = cols < n_cols
 
+    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
     # Pre-load weights once (same optimization as forward pass)
     w = tl.load(W_ptr + cols, mask=mask, other=0.0)
     w_f32 = w.to(tl.float32)
-    n_cols_f32 = n_cols.to(tl.float32)
 
     # Calculate pointers for this specific row
-    row_X_ptr = X_ptr + row_idx * stride_x
-    row_DX_ptr = DX_ptr + row_idx * stride_dx
-    row_DY_ptr = DY_ptr + row_idx * stride_dy
-    row_Mean_ptr = Mean_ptr + row_idx
-    row_RSTD_ptr = RSTD_ptr + row_idx
-
-    # Load data for this row
-    x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
-    dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
-    mean = tl.load(row_Mean_ptr)
-    rstd = tl.load(row_RSTD_ptr)
-
-    # Convert to fp32 for numerical stability
-    x_f32 = x.to(tl.float32)
-    dy_f32 = dy.to(tl.float32)
-    mean_f32 = mean.to(tl.float32)
-    rstd_f32 = rstd.to(tl.float32)
-
-    # Compute backward pass for this row
-    x_hat = (x_f32 - mean_f32) * rstd_f32
-    wdy = w_f32 * dy_f32
-    c1 = tl.sum(x_hat * wdy, axis=0) / n_cols_f32
-    c2 = tl.sum(wdy, axis=0) / n_cols_f32
-    dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
-
-    # Store input gradient
-    tl.store(row_DX_ptr + cols, dx.to(dtype), mask=mask)
-
-    # Accumulate weight and bias gradients using atomic operations
-    dw = dy_f32 * x_hat
-    db = dy_f32
-    tl.atomic_add(DW_ptr + cols, dw.to(atomic_dtype), mask=mask)
-    tl.atomic_add(DB_ptr + cols, db.to(atomic_dtype), mask=mask)
+    row_X_ptr = X_ptr + row_start * stride_x
+    row_DX_ptr = DX_ptr + row_start * stride_dx
+    row_DY_ptr = DY_ptr + row_start * stride_dy
+    row_Mean_ptr = Mean_ptr + row_start
+    row_RSTD_ptr = RSTD_ptr + row_start
+
+    for _ in range(row_start, row_end):
+        # Load data for this row
+        x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
+        dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
+        mean = tl.load(row_Mean_ptr)
+        rstd = tl.load(row_RSTD_ptr)
+
+        # Convert to fp32 for numerical stability
+        x_f32 = x.to(tl.float32)
+        dy_f32 = dy.to(tl.float32)
+        mean_f32 = mean.to(tl.float32)
+        rstd_f32 = rstd.to(tl.float32)
+
+        # Compute backward pass for this row
+        x_hat = (x_f32 - mean_f32) * rstd_f32
+        wdy = w_f32 * dy_f32
+        c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
+        c2 = tl.sum(wdy, axis=0) / n_cols
+        dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
+
+        # Store input gradient
+        tl.store(row_DX_ptr + cols, dx, mask=mask)
+
+        # Accumulate weight and bias gradients for this thread block's assigned rows
+        dw = dy_f32 * x_hat
+        db = dy_f32
+        dW_row += dw
+        db_row += db
+
+        row_X_ptr += stride_x
+        row_DX_ptr += stride_dx
+        row_DY_ptr += stride_dy
+        row_Mean_ptr += stride_mean
+        row_RSTD_ptr += stride_rstd
+
+    tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
+    tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)
 
 
 def layer_norm_forward(X, W, B, eps):
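For reference, the per-row input gradient the rewritten kernel computes (the `c1`, `c2`, and `dx` lines above) is the standard LayerNorm backward closed form. Writing it out with $N = \texttt{n\_cols}$, $\hat{x}_i = (x_i - \mu)\cdot\text{rstd}$, and $\mathrm{d}y$ the output gradient:

\[
\frac{\partial L}{\partial x_i}
= \text{rstd}\left(
w_i\,\mathrm{d}y_i
\;-\; \underbrace{\tfrac{1}{N}\textstyle\sum_j w_j\,\mathrm{d}y_j}_{c_2}
\;-\; \hat{x}_i\,\underbrace{\tfrac{1}{N}\textstyle\sum_j w_j\,\mathrm{d}y_j\,\hat{x}_j}_{c_1}
\right),
\qquad
\frac{\partial L}{\partial w_j} = \sum_{\text{rows}} \mathrm{d}y_j\,\hat{x}_j,
\qquad
\frac{\partial L}{\partial b_j} = \sum_{\text{rows}} \mathrm{d}y_j .
\]

The change in this hunk is not to the math but to how the row sums for $\partial L/\partial w$ and $\partial L/\partial b$ are accumulated: each program now loops over a contiguous block of rows and keeps fp32 partial sums instead of issuing `tl.atomic_add` per row.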
@@ -230,31 +249,25 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
     dY = dY.view(-1, dim)
     n_rows, n_cols = dY.shape
 
-    # Allocate gradient tensors
-    DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
-    # Use float32 for weight/bias gradients if bfloat16 (due to atomic_add limitation)
-    grad_dtype = torch.float32 if W.dtype == torch.bfloat16 else W.dtype
-    DW = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
-    DB = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
+    sm_count = 1
+    if X.device.type == "cuda":
+        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+    elif X.device.type == "xpu":
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+
+    # fp32 for numerical stability especially.
+    _DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+    _DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
 
     # Calculate optimal block size and warp configuration
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
     if n_cols > BLOCK_SIZE:
         raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
+    rows_per_program = math.ceil(n_rows / sm_count)
+    grid = (sm_count,)
 
-    # Determine dtype for triton operations
-    triton_dtype = (
-        tl.float32
-        if X.dtype == torch.float32
-        else tl.bfloat16
-        if X.dtype == torch.bfloat16
-        else tl.float16
-        if X.dtype == torch.float16
-        else tl.float32  # fallback
-    )
-
-    # Use float32 for atomic operations if bfloat16 is not supported
-    atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype
+    # Allocate gradient tensors
+    DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
 
     kernel_args = {"num_warps": num_warps}
     # XPU-specific optimization
@@ -262,28 +275,33 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
         kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
 
     # Launch kernel with one thread block per row for optimal performance
-    grid = (n_rows,)
     _layer_norm_backward_kernel[grid](
         X,
+        X.stride(0),
         W,
         Mean,
+        Mean.stride(0),
         RSTD,
+        RSTD.stride(0),
        DX,
-        DW,
-        DB,
-        dY,
-        X.stride(0),
        DX.stride(0),
+        _DW,
+        _DW.stride(0),
+        _DB,
+        _DB.stride(0),
+        dY,
        dY.stride(0),
+        n_rows,
        n_cols,
+        rows_per_program=rows_per_program,
        BLOCK_SIZE=BLOCK_SIZE,
-        dtype=triton_dtype,
-        atomic_dtype=atomic_dtype,
        **kernel_args,
    )
 
    DX = DX.view(*shape)
-    return DX, DW.to(W.dtype), DB.to(W.dtype)
+    DW = _DW.sum(dim=0).to(W.dtype)
+    DB = _DB.sum(dim=0).to(B.dtype)
+    return DX, DW, DB
 
 
 class LigerLayerNormFunction(torch.autograd.Function):
liger_kernel/ops/poly_norm.py (new file)

@@ -0,0 +1,390 @@
+import operator
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import get_npu_multi_processor_count
+from liger_kernel.utils import is_npu_available
+
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
+    try:
+        from triton.language.extra.libdevice import rsqrt
+    except ModuleNotFoundError:
+        from triton.language.extra.cuda.libdevice import rsqrt
+else:
+    from triton.language.math import rsqrt
+
+
+@triton.jit
+def _poly_norm_forward_kernel(
+    Y_ptr,
+    Y_row_stride,
+    X_ptr,
+    X_row_stride,
+    W_ptr,  # weight: [3] for [w0, w1, w2]
+    B_ptr,  # bias: scalar
+    RSTD_ptr,  # cache rstd for backward: shape (n_rows, 3)
+    RSTD_row_stride,
+    n_cols,
+    eps,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    PolyNorm formula:
+        y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+        where norm(u) = u / sqrt(mean(u²) + ε)
+
+    Reference:
+    1. https://github.com/BryceZhuo/PolyCom/
+    2. https://arxiv.org/pdf/2411.03884
+
+    Cache rstd values for backward pass
+    """
+    row_idx = tl.program_id(0).to(tl.int64)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = col_offsets < n_cols
+
+    # Load pointers
+    Y_ptr += row_idx * Y_row_stride
+    X_ptr += row_idx * X_row_stride
+    RSTD_ptr += row_idx * RSTD_row_stride
+
+    # Load input row
+    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
+
+    # Load weights and bias
+    w0 = tl.load(W_ptr + 0)
+    w1 = tl.load(W_ptr + 1)
+    w2 = tl.load(W_ptr + 2)
+    b = tl.load(B_ptr)
+
+    # Compute x³, x², x
+    X_pow3 = X_row * X_row * X_row
+    X_pow2 = X_row * X_row
+    X_pow1 = X_row
+
+    # Compute norm(x³): norm(u) = u * rsqrt(mean(u²) + eps)
+    mean_square_3 = tl.sum(X_pow3 * X_pow3, axis=0) / n_cols
+    rstd_3 = rsqrt(mean_square_3 + eps)
+    norm_x3 = X_pow3 * rstd_3
+
+    # Compute norm(x²)
+    mean_square_2 = tl.sum(X_pow2 * X_pow2, axis=0) / n_cols
+    rstd_2 = rsqrt(mean_square_2 + eps)
+    norm_x2 = X_pow2 * rstd_2
+
+    # Compute norm(x)
+    mean_square_1 = tl.sum(X_pow1 * X_pow1, axis=0) / n_cols
+    rstd_1 = rsqrt(mean_square_1 + eps)
+    norm_x1 = X_pow1 * rstd_1
+
+    # Cache rstd values for backward
+    tl.store(RSTD_ptr + 0, rstd_3)
+    tl.store(RSTD_ptr + 1, rstd_2)
+    tl.store(RSTD_ptr + 2, rstd_1)
+
+    # Compute output: y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+    Y_row = w0 * norm_x3 + w1 * norm_x2 + w2 * norm_x1 + b
+
+    # Store output
+    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
+
+
+@triton.jit
+def _poly_norm_backward_kernel(
+    dY_ptr,
+    dY_row_stride,
+    dX_ptr,
+    dX_row_stride,
+    X_ptr,
+    X_row_stride,
+    W_ptr,
+    RSTD_ptr,
+    RSTD_row_stride,
+    dW_ptr,  # shape: (n_programs, 3)
+    dW_row_stride,
+    dB_ptr,  # shape: (n_programs,)
+    n_rows,
+    n_cols,
+    rows_per_program: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    PolyNorm Backward Kernel Gradient:
+        ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
+
+    where:
+        - D_p = RMS(x^p) = 1/rstd_p
+        - S_p = sum(grad * x^p) over the row
+        - d = n_cols
+        - p ∈ {3, 2, 1}
+    """
+    row_block_id = tl.program_id(0).to(tl.int64)
+    row_start = row_block_id * rows_per_program
+    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = col_offsets < n_cols
+
+    # Initialize accumulators for weight and bias gradients (scalars)
+    dW0_acc = 0.0
+    dW1_acc = 0.0
+    dW2_acc = 0.0
+    dB_acc = 0.0
+
+    # Load weights
+    w0 = tl.load(W_ptr + 0).to(tl.float32)
+    w1 = tl.load(W_ptr + 1).to(tl.float32)
+    w2 = tl.load(W_ptr + 2).to(tl.float32)
+
+    dY_ptr += row_start * dY_row_stride
+    dX_ptr += row_start * dX_row_stride
+    X_ptr += row_start * X_row_stride
+    RSTD_ptr += row_start * RSTD_row_stride
+
+    for _ in range(row_start, row_end):
+        # Load input and gradient
+        dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
+        X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
+
+        # Load cached rstd values
+        rstd_3 = tl.load(RSTD_ptr + 0).to(tl.float32)
+        rstd_2 = tl.load(RSTD_ptr + 1).to(tl.float32)
+        rstd_1 = tl.load(RSTD_ptr + 2).to(tl.float32)
+
+        # Compute powers
+        X_pow3 = X_row * X_row * X_row
+        X_pow2 = X_row * X_row
+        X_pow1 = X_row
+
+        # Accumulate bias gradient: dB = sum(dY)
+        dB_acc += tl.sum(dY_row, axis=0)
+
+        # Compute gradient w.r.t. input using closed-form formula
+        # For p=3: ∂L/∂x from w0 * norm(x³)
+        S_3 = tl.sum(dY_row * X_pow3, axis=0)  # scalar
+        grad_x_3 = w0 * (
+            3.0 * X_pow2 * rstd_3 * dY_row
+            - (3.0 / n_cols) * X_row * X_row * X_row * X_row * X_row * (rstd_3 * rstd_3 * rstd_3) * S_3
+        )
+
+        # For p=2: ∂L/∂x from w1 * norm(x²)
+        S_2 = tl.sum(dY_row * X_pow2, axis=0)  # scalar
+        grad_x_2 = w1 * (
+            2.0 * X_row * rstd_2 * dY_row - (2.0 / n_cols) * X_row * X_row * X_row * (rstd_2 * rstd_2 * rstd_2) * S_2
+        )
+
+        # For p=1: ∂L/∂x from w2 * norm(x)
+        S_1 = tl.sum(dY_row * X_pow1, axis=0)  # scalar
+        grad_x_1 = w2 * (1.0 * rstd_1 * dY_row - (1.0 / n_cols) * X_row * (rstd_1 * rstd_1 * rstd_1) * S_1)
+
+        # Accumulate weight gradients using closed-form: dW_p = rstd_p * S_p
+        dW0_acc += rstd_3 * S_3
+        dW1_acc += rstd_2 * S_2
+        dW2_acc += rstd_1 * S_1
+
+        # Total gradient
+        dX_row = grad_x_3 + grad_x_2 + grad_x_1
+
+        # Store gradient
+        tl.store(dX_ptr + col_offsets, dX_row, mask=mask)
+
+        # Update pointers
+        dY_ptr += dY_row_stride
+        dX_ptr += dX_row_stride
+        X_ptr += X_row_stride
+        RSTD_ptr += RSTD_row_stride
+
+    # Store accumulated gradients (scalars)
+    tl.store(dW_ptr + row_block_id * dW_row_stride + 0, dW0_acc)
+    tl.store(dW_ptr + row_block_id * dW_row_stride + 1, dW1_acc)
+    tl.store(dW_ptr + row_block_id * dW_row_stride + 2, dW2_acc)
+    tl.store(dB_ptr + row_block_id, dB_acc)
+
+
+def poly_norm_forward(X, W, B, eps=1e-6):
+    """
+    PolyNorm Forward Pass
+
+    Args:
+        X: input tensor of shape (*, H) where H is hidden dimension
+        W: weight tensor of shape (3,) for [w0, w1, w2]
+        B: bias scalar tensor
+        eps: epsilon for numerical stability
+
+    Returns:
+        Y: output tensor of same shape as X
+        X: reshaped input (for backward)
+        RSTD: cached rstd values (for backward)
+        BLOCK_SIZE: block size used
+        num_warps: number of warps used
+    """
+    shape = X.shape
+    dim = shape[-1]
+    X = X.view(-1, dim)
+    n_rows, n_cols = X.shape
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    # RSTD is to cache rstd for each row
+    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+    RSTD = torch.empty((n_rows, 3), dtype=torch.float32, device=X.device)
+
+    # Check constraints
+    assert W.shape[0] == 3, "Weight tensor must have shape (3,)"
+    assert B.numel() == 1, "Bias must be a scalar"
+
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
+
+    # Launch kernel
+    _poly_norm_forward_kernel[(n_rows,)](
+        Y,
+        Y.stride(0),
+        X,
+        X.stride(0),
+        W,
+        B,
+        RSTD,
+        RSTD.stride(0),
+        n_cols,
+        eps,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+        **kernel_args,
+    )
+
+    return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps
+
+
+def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
+    """
+    PolyNorm Backward Pass
+
+    Args:
+        dY: gradient of output
+        X: input tensor (already reshaped to 2D)
+        W: weight tensor
+        RSTD: cached rstd values from forward
+        BLOCK_SIZE: block size from forward
+        num_warps: number of warps from forward
+        in_place: whether to in-place modify dY to store dX (saves memory)
+
+    Returns:
+        dX: gradient w.r.t. input
+        dW: gradient w.r.t. weight
+        dB: gradient w.r.t. bias
+    """
+    shape = dY.shape
+    dim = shape[-1]
+    dY = dY.view(-1, dim)
+    n_rows, n_cols = dY.shape
+
+    # Get number of SMs for parallelization
+    import math
+
+    sm_count = 1
+    if X.device.type == "cuda":
+        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+    elif X.device.type == "xpu":
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+    elif X.device.type == "npu":
+        sm_count = get_npu_multi_processor_count()
+
+    # Allocate or reuse gradients
+    if in_place is True:
+        dX = dY
+    else:
+        dX = torch.zeros_like(dY)
+
+    _dW = torch.empty((sm_count, 3), dtype=torch.float32, device=W.device)
+    _dB = torch.empty((sm_count,), dtype=torch.float32, device=W.device)
+
+    rows_per_program = math.ceil(n_rows / sm_count)
+    grid = (sm_count,)
+
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
+
+    # Launch backward kernel
+    _poly_norm_backward_kernel[grid](
+        dY,
+        dY.stride(0),
+        dX,
+        dX.stride(0),
+        X,
+        X.stride(0),
+        W,
+        RSTD,
+        RSTD.stride(0),
+        _dW,
+        _dW.stride(0),
+        _dB,
+        n_rows,
+        n_cols,
+        rows_per_program,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+        **kernel_args,
+    )
+
+    # Reduce gradients across SMs
+    dX = dX.view(*shape)
+    dW = _dW.sum(dim=0).to(W.dtype)
+    dB = _dB.sum().to(W.dtype)
+
+    return dX, dW, dB
+
+
+class LigerPolyNormFunction(torch.autograd.Function):
+    """
+    PolyNorm Function with forward and backward pass
+
+    PolyNorm formula:
+        y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+        where norm(u) = u / sqrt(mean(u²) + ε)
+
+    Backward uses closed-form gradient:
+        ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
+    """
+
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, X, W, B, eps=1e-6, in_place=True):
+        """
+        Args:
+            X: input tensor of shape (B, T, H) or (BxT, H)
+            W: weight tensor of shape (3,) for [w0, w1, w2]
+            B: bias scalar
+            eps: epsilon for numerical stability
+            in_place: whether to in-place modify grad_output in backward (saves memory)
+
+        Returns:
+            Y: output tensor of same shape as X
+        """
+        Y, X, RSTD, BLOCK_SIZE, num_warps = poly_norm_forward(X, W, B, eps)
+        ctx.BLOCK_SIZE = BLOCK_SIZE
+        ctx.num_warps = num_warps
+        ctx.in_place = in_place
+        ctx.save_for_backward(X, W, RSTD)
+        return Y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_output):
+        """
+        Args:
+            grad_output: gradient of output
+
+        Returns:
+            dX, dW, dB: gradients w.r.t. X, W, B
+        """
+        X, W, RSTD = ctx.saved_tensors
+        dX, dW, dB = poly_norm_backward(grad_output, X, W, RSTD, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place)
+        return dX, dW, dB, None, None
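The new kernel computes the same quantity as a straightforward eager-mode PolyNorm. The snippet below is a minimal reference sketch of that formula (y = w0·norm(x³) + w1·norm(x²) + w2·norm(x) + b with norm(u) = u / sqrt(mean(u²) + eps)); it is not part of the package, but it can serve as a correctness oracle when testing the Triton path. The usage lines are commented out because they assume a CUDA device is available.

```python
import torch

def poly_norm_reference(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Eager-mode PolyNorm: y = w[0]*norm(x**3) + w[1]*norm(x**2) + w[2]*norm(x) + b."""
    def norm(u: torch.Tensor) -> torch.Tensor:
        # norm(u) = u / sqrt(mean(u**2, dim=-1) + eps), applied per row
        return u * torch.rsqrt(u.pow(2).mean(dim=-1, keepdim=True) + eps)

    return w[0] * norm(x**3) + w[1] * norm(x**2) + w[2] * norm(x) + b

# Example comparison against the fused path (assumes CUDA and the import of LigerPolyNormFunction):
# x = torch.randn(2, 16, 64, device="cuda", dtype=torch.float32)
# w = torch.randn(3, device="cuda"); b = torch.randn((), device="cuda")
# y_ref = poly_norm_reference(x, w, b)
# y_tri = LigerPolyNormFunction.apply(x, w, b, 1e-6, False)
# torch.testing.assert_close(y_ref, y_tri, rtol=1e-4, atol=1e-4)
```

Like the reworked LayerNorm backward, the PolyNorm backward kernel assigns `rows_per_program` rows to each of `sm_count` programs and reduces the per-program `_dW`/`_dB` partials on the host with `sum(dim=0)`.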
liger_kernel/ops/rms_norm.py

@@ -21,8 +21,10 @@ from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import torch_to_triton_dtype
+from liger_kernel.utils import get_npu_multi_processor_count
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
@@ -349,7 +351,8 @@ def _block_rms_norm_backward_kernel(
 
         # calculate the gradient of W
         if casting_mode == _CASTING_MODE_LLAMA:
-            dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]).to(X_dtype), 0)
+            # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
+            dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
         else:
             # here X_row is already in fp32 (see previous if block)
             dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
@@ -449,6 +452,8 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
     elif X.device.type == "xpu":
         sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+    elif X.device.type == "npu":
+        sm_count = get_npu_multi_processor_count()
 
     # fp32 for numerical stability especially.
     _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
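The upcast of the dW product to fp32 before `tl.sum` follows the same rationale as the fp32 partial buffers elsewhere in this diff: long reductions kept in bf16 lose low-order contributions. A tiny, illustrative demonstration of that effect (not library code; strict bf16 accumulation is forced with a Python loop, since PyTorch's own reductions may already accumulate in higher precision):

```python
import torch

vals = torch.full((4096,), 0.01, dtype=torch.bfloat16)  # true sum is about 41.0

acc_bf16 = torch.zeros((), dtype=torch.bfloat16)
for v in vals:                   # strict bf16 running sum: increments get rounded away
    acc_bf16 = acc_bf16 + v

acc_f32 = vals.to(torch.float32).sum()  # same bf16 values, accumulated in fp32

print(float(acc_bf16), float(acc_f32))  # the bf16 total falls far short of ~41.0; the fp32 one does not
```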