liger-kernel 0.6.2__py3-none-any.whl → 0.6.4__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Files changed (61)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +13 -4
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +25 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +46 -9
  5. liger_kernel/chunked_loss/jsd_loss.py +23 -7
  6. liger_kernel/ops/cross_entropy.py +118 -62
  7. liger_kernel/ops/fused_linear_cross_entropy.py +97 -13
  8. liger_kernel/ops/grpo_loss.py +3 -1
  9. liger_kernel/ops/layer_norm.py +86 -69
  10. liger_kernel/ops/poly_norm.py +386 -0
  11. liger_kernel/ops/tiled_mlp.py +136 -0
  12. liger_kernel/transformers/__init__.py +36 -0
  13. liger_kernel/transformers/cross_entropy.py +8 -3
  14. liger_kernel/transformers/functional.py +31 -6
  15. liger_kernel/transformers/fused_linear_cross_entropy.py +13 -4
  16. liger_kernel/transformers/grpo_loss.py +56 -1
  17. liger_kernel/transformers/model/falcon_h1.py +122 -0
  18. liger_kernel/transformers/model/gemma.py +19 -7
  19. liger_kernel/transformers/model/gemma2.py +22 -7
  20. liger_kernel/transformers/model/gemma3.py +52 -14
  21. liger_kernel/transformers/model/glm4.py +18 -5
  22. liger_kernel/transformers/model/glm4v.py +19 -6
  23. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  24. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  25. liger_kernel/transformers/model/internvl.py +157 -0
  26. liger_kernel/transformers/model/llama.py +16 -6
  27. liger_kernel/transformers/model/llama4.py +18 -5
  28. liger_kernel/transformers/model/llava.py +18 -6
  29. liger_kernel/transformers/model/loss_utils.py +32 -3
  30. liger_kernel/transformers/model/mistral.py +17 -7
  31. liger_kernel/transformers/model/mixtral.py +24 -9
  32. liger_kernel/transformers/model/mllama.py +14 -5
  33. liger_kernel/transformers/model/olmo2.py +18 -5
  34. liger_kernel/transformers/model/olmo3.py +142 -0
  35. liger_kernel/transformers/model/output_classes.py +147 -0
  36. liger_kernel/transformers/model/paligemma.py +41 -5
  37. liger_kernel/transformers/model/phi3.py +16 -8
  38. liger_kernel/transformers/model/qwen2.py +18 -4
  39. liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
  40. liger_kernel/transformers/model/qwen2_vl.py +24 -7
  41. liger_kernel/transformers/model/qwen3.py +22 -6
  42. liger_kernel/transformers/model/qwen3_moe.py +27 -7
  43. liger_kernel/transformers/model/qwen3_next.py +146 -0
  44. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  45. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  46. liger_kernel/transformers/model/smollm3.py +17 -7
  47. liger_kernel/transformers/model/smolvlm.py +158 -0
  48. liger_kernel/transformers/monkey_patch.py +830 -3
  49. liger_kernel/transformers/multi_token_attention.py +1 -1
  50. liger_kernel/transformers/poly_norm.py +42 -0
  51. liger_kernel/transformers/rms_norm.py +7 -0
  52. liger_kernel/transformers/rope.py +43 -0
  53. liger_kernel/transformers/swiglu.py +17 -0
  54. liger_kernel/transformers/tiled_mlp.py +133 -0
  55. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.4.dist-info}/METADATA +16 -10
  56. liger_kernel-0.6.4.dist-info/RECORD +118 -0
  57. liger_kernel-0.6.2.dist-info/RECORD +0 -104
  58. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.4.dist-info}/WHEEL +0 -0
  59. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.4.dist-info}/licenses/LICENSE +0 -0
  60. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.4.dist-info}/licenses/NOTICE +0 -0
  61. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.4.dist-info}/top_level.txt +0 -0
liger_kernel/ops/layer_norm.py
@@ -1,3 +1,4 @@
+import math
 import operator
 
 import torch
@@ -63,12 +64,11 @@ def _layer_norm_forward_kernel(
     X_f32 = X_row.to(tl.float32)
 
     # Compute statistics in fp32 for numerical stability
-    n_cols_f32 = n_cols.to(tl.float32)
-    mean = tl.sum(X_f32, axis=0) / n_cols_f32
+    mean = tl.sum(X_f32, axis=0) / n_cols
     X_centered = X_f32 - mean
     # Apply mask to variance calculation to exclude contributions from masked elements
     X_centered_masked = tl.where(mask, X_centered, 0.0)
-    var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols_f32
+    var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols
     rstd = rsqrt(var + eps)
 
     # Store statistics (convert back to original dtype only once)
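The forward hunk above only drops the explicit n_cols cast; the statistics are the usual per-row LayerNorm moments computed in fp32. A minimal eager-PyTorch sketch of the same math (a reference for reading the kernel, not part of liger's API; assumes a 2D fp32 input):

import torch

def layer_norm_stats_reference(x: torch.Tensor, eps: float = 1e-5):
    # x: (n_rows, n_cols); mirrors the per-row math in the Triton hunk above
    mean = x.mean(dim=-1, keepdim=True)                 # tl.sum(X_f32) / n_cols
    var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)  # masked, biased variance in the kernel
    rstd = torch.rsqrt(var + eps)                       # rsqrt(var + eps)
    return mean, rstd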
@@ -86,69 +86,87 @@ def _layer_norm_forward_kernel(
 @triton.jit
 def _layer_norm_backward_kernel(
     X_ptr,  # pointer to input, shape (n_rows, n_cols)
+    stride_x,  # stride of each row in input
     W_ptr,  # pointer to weights, shape (n_cols,)
     Mean_ptr,  # pointer to mean, shape (n_rows,)
+    stride_mean,  # stride of each row in mean
     RSTD_ptr,  # pointer to rstd, shape (n_rows,)
+    stride_rstd,  # stride of each row in rstd
     DX_ptr,  # pointer to input grad, shape (n_rows, n_cols)
+    stride_dx,  # stride of each row in input grad
     DW_ptr,  # pointer to weights grad, shape (n_cols,)
+    stride_dw,  # stride of each row in weights grad
     DB_ptr,  # pointer to bias grad, shape (n_cols,)
+    stride_db,  # stride of each row in bias grad
     DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
-    stride_x,  # stride of each row in input
-    stride_dx,  # stride of each row in input grad
     stride_dy,  # stride of each row in output grad
+    n_rows,
     n_cols,
+    rows_per_program: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
-    dtype: tl.constexpr,
-    atomic_dtype: tl.constexpr,
 ):
     """
     References:
     https://arxiv.org/abs/1607.06450
     https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
     """
-    row_idx = tl.program_id(0).to(tl.int64)
+    row_block_id = tl.program_id(0).to(tl.int64)
+    row_start = row_block_id * rows_per_program
+    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
     cols = tl.arange(0, BLOCK_SIZE)
     mask = cols < n_cols
 
+    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
     # Pre-load weights once (same optimization as forward pass)
     w = tl.load(W_ptr + cols, mask=mask, other=0.0)
     w_f32 = w.to(tl.float32)
-    n_cols_f32 = n_cols.to(tl.float32)
 
     # Calculate pointers for this specific row
-    row_X_ptr = X_ptr + row_idx * stride_x
-    row_DX_ptr = DX_ptr + row_idx * stride_dx
-    row_DY_ptr = DY_ptr + row_idx * stride_dy
-    row_Mean_ptr = Mean_ptr + row_idx
-    row_RSTD_ptr = RSTD_ptr + row_idx
-
-    # Load data for this row
-    x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
-    dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
-    mean = tl.load(row_Mean_ptr)
-    rstd = tl.load(row_RSTD_ptr)
-
-    # Convert to fp32 for numerical stability
-    x_f32 = x.to(tl.float32)
-    dy_f32 = dy.to(tl.float32)
-    mean_f32 = mean.to(tl.float32)
-    rstd_f32 = rstd.to(tl.float32)
-
-    # Compute backward pass for this row
-    x_hat = (x_f32 - mean_f32) * rstd_f32
-    wdy = w_f32 * dy_f32
-    c1 = tl.sum(x_hat * wdy, axis=0) / n_cols_f32
-    c2 = tl.sum(wdy, axis=0) / n_cols_f32
-    dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
-
-    # Store input gradient
-    tl.store(row_DX_ptr + cols, dx.to(dtype), mask=mask)
-
-    # Accumulate weight and bias gradients using atomic operations
-    dw = dy_f32 * x_hat
-    db = dy_f32
-    tl.atomic_add(DW_ptr + cols, dw.to(atomic_dtype), mask=mask)
-    tl.atomic_add(DB_ptr + cols, db.to(atomic_dtype), mask=mask)
+    row_X_ptr = X_ptr + row_start * stride_x
+    row_DX_ptr = DX_ptr + row_start * stride_dx
+    row_DY_ptr = DY_ptr + row_start * stride_dy
+    row_Mean_ptr = Mean_ptr + row_start
+    row_RSTD_ptr = RSTD_ptr + row_start
+
+    for _ in range(row_start, row_end):
+        # Load data for this row
+        x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
+        dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
+        mean = tl.load(row_Mean_ptr)
+        rstd = tl.load(row_RSTD_ptr)
+
+        # Convert to fp32 for numerical stability
+        x_f32 = x.to(tl.float32)
+        dy_f32 = dy.to(tl.float32)
+        mean_f32 = mean.to(tl.float32)
+        rstd_f32 = rstd.to(tl.float32)
+
+        # Compute backward pass for this row
+        x_hat = (x_f32 - mean_f32) * rstd_f32
+        wdy = w_f32 * dy_f32
+        c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
+        c2 = tl.sum(wdy, axis=0) / n_cols
+        dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
+
+        # Store input gradient
+        tl.store(row_DX_ptr + cols, dx, mask=mask)
+
+        # Accumulate weight and bias gradients for this thread block's assigned rows
+        dw = dy_f32 * x_hat
+        db = dy_f32
+        dW_row += dw
+        db_row += db
+
+        row_X_ptr += stride_x
+        row_DX_ptr += stride_dx
+        row_DY_ptr += stride_dy
+        row_Mean_ptr += stride_mean
+        row_RSTD_ptr += stride_rstd
+
+    tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
+    tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)
 
 
 def layer_norm_forward(X, W, B, eps):
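The per-row gradient math inside the new loop is unchanged from 0.6.2; what changes is the accumulation strategy: each program now sums dW/dB for its assigned rows into dW_row/db_row and writes one partial row, instead of issuing tl.atomic_add per row. A rough PyTorch equivalent of a single row's backward, useful for sanity-checking the formulas (a sketch assuming 1D fp32 inputs, not liger's API):

import torch

def layer_norm_backward_row_reference(x, dy, w, mean, rstd):
    # x, dy, w: 1D tensors of length n_cols; mean, rstd: scalars cached by the forward pass
    x_hat = (x - mean) * rstd
    wdy = w * dy
    c1 = (x_hat * wdy).mean()   # tl.sum(x_hat * wdy) / n_cols
    c2 = wdy.mean()             # tl.sum(wdy) / n_cols
    dx = (wdy - (x_hat * c1 + c2)) * rstd
    dw = dy * x_hat             # accumulated into dW_row across the program's rows
    db = dy                     # accumulated into db_row across the program's rows
    return dx, dw, db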
@@ -230,31 +248,25 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
     dY = dY.view(-1, dim)
     n_rows, n_cols = dY.shape
 
-    # Allocate gradient tensors
-    DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
-    # Use float32 for weight/bias gradients if bfloat16 (due to atomic_add limitation)
-    grad_dtype = torch.float32 if W.dtype == torch.bfloat16 else W.dtype
-    DW = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
-    DB = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
+    sm_count = 1
+    if X.device.type == "cuda":
+        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+    elif X.device.type == "xpu":
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+
+    # fp32 for numerical stability especially.
+    _DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+    _DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
 
     # Calculate optimal block size and warp configuration
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
     if n_cols > BLOCK_SIZE:
         raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
+    rows_per_program = math.ceil(n_rows / sm_count)
+    grid = (sm_count,)
 
-    # Determine dtype for triton operations
-    triton_dtype = (
-        tl.float32
-        if X.dtype == torch.float32
-        else tl.bfloat16
-        if X.dtype == torch.bfloat16
-        else tl.float16
-        if X.dtype == torch.float16
-        else tl.float32  # fallback
-    )
-
-    # Use float32 for atomic operations if bfloat16 is not supported
-    atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype
+    # Allocate gradient tensors
+    DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
 
     kernel_args = {"num_warps": num_warps}
     # XPU-specific optimization
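With this hunk the backward launch becomes a persistent-style grid: one program per SM, each owning ceil(n_rows / sm_count) consecutive rows and one partial row of the (sm_count, n_cols) fp32 gradient buffers. Illustrative arithmetic only (the device numbers are an assumption, not taken from the diff):

import math

n_rows, sm_count = 8192, 108                     # e.g. 8192 token rows on a GPU with 108 SMs
rows_per_program = math.ceil(n_rows / sm_count)  # 76
# program i covers rows [i * 76, min((i + 1) * 76, 8192));
# programs 0..106 each handle 76 rows, program 107 handles the remaining 60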
@@ -262,28 +274,33 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
         kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
 
     # Launch kernel with one thread block per row for optimal performance
-    grid = (n_rows,)
     _layer_norm_backward_kernel[grid](
         X,
+        X.stride(0),
         W,
         Mean,
+        Mean.stride(0),
         RSTD,
+        RSTD.stride(0),
         DX,
-        DW,
-        DB,
-        dY,
-        X.stride(0),
         DX.stride(0),
+        _DW,
+        _DW.stride(0),
+        _DB,
+        _DB.stride(0),
+        dY,
         dY.stride(0),
+        n_rows,
         n_cols,
+        rows_per_program=rows_per_program,
         BLOCK_SIZE=BLOCK_SIZE,
-        dtype=triton_dtype,
-        atomic_dtype=atomic_dtype,
         **kernel_args,
     )
 
     DX = DX.view(*shape)
-    return DX, DW.to(W.dtype), DB.to(W.dtype)
+    DW = _DW.sum(dim=0).to(W.dtype)
+    DB = _DB.sum(dim=0).to(B.dtype)
+    return DX, DW, DB
 
 
 class LigerLayerNormFunction(torch.autograd.Function):
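Taken together, the layer_norm.py changes replace atomic accumulation of DW/DB with a two-stage reduction, which also removes the old bfloat16 atomic_add workaround: stage one writes per-program partial sums in fp32, stage two reduces them on the host. A minimal sketch of the pattern in plain PyTorch (illustrative only, not the kernel):

import torch

sm_count, n_cols = 4, 8
# stage 1: each program writes one fp32 partial gradient row (done inside the Triton kernel)
partial_dw = torch.randn(sm_count, n_cols, dtype=torch.float32)
partial_db = torch.randn(sm_count, n_cols, dtype=torch.float32)
# stage 2: host-side reduction, as layer_norm_backward now does with _DW.sum(dim=0)
dw = partial_dw.sum(dim=0)
db = partial_db.sum(dim=0)
# the reduction no longer depends on the scheduling order of tl.atomic_add calls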
liger_kernel/ops/poly_norm.py (new file)
@@ -0,0 +1,386 @@
+import operator
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        from triton.language.extra.libdevice import rsqrt
+    except ModuleNotFoundError:
+        from triton.language.extra.cuda.libdevice import rsqrt
+else:
+    from triton.language.math import rsqrt
+
+
+@triton.jit
+def _poly_norm_forward_kernel(
+    Y_ptr,
+    Y_row_stride,
+    X_ptr,
+    X_row_stride,
+    W_ptr,  # weight: [3] for [w0, w1, w2]
+    B_ptr,  # bias: scalar
+    RSTD_ptr,  # cache rstd for backward: shape (n_rows, 3)
+    RSTD_row_stride,
+    n_cols,
+    eps,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    PolyNorm formula:
+        y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+    where norm(u) = u / sqrt(mean(u²) + ε)
+
+    Reference:
+    1. https://github.com/BryceZhuo/PolyCom/
+    2. https://arxiv.org/pdf/2411.03884
+
+    Cache rstd values for backward pass
+    """
+    row_idx = tl.program_id(0).to(tl.int64)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = col_offsets < n_cols
+
+    # Load pointers
+    Y_ptr += row_idx * Y_row_stride
+    X_ptr += row_idx * X_row_stride
+    RSTD_ptr += row_idx * RSTD_row_stride
+
+    # Load input row
+    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
+
+    # Load weights and bias
+    w0 = tl.load(W_ptr + 0)
+    w1 = tl.load(W_ptr + 1)
+    w2 = tl.load(W_ptr + 2)
+    b = tl.load(B_ptr)
+
+    # Compute x³, x², x
+    X_pow3 = X_row * X_row * X_row
+    X_pow2 = X_row * X_row
+    X_pow1 = X_row
+
+    # Compute norm(x³): norm(u) = u * rsqrt(mean(u²) + eps)
+    mean_square_3 = tl.sum(X_pow3 * X_pow3, axis=0) / n_cols
+    rstd_3 = rsqrt(mean_square_3 + eps)
+    norm_x3 = X_pow3 * rstd_3
+
+    # Compute norm(x²)
+    mean_square_2 = tl.sum(X_pow2 * X_pow2, axis=0) / n_cols
+    rstd_2 = rsqrt(mean_square_2 + eps)
+    norm_x2 = X_pow2 * rstd_2
+
+    # Compute norm(x)
+    mean_square_1 = tl.sum(X_pow1 * X_pow1, axis=0) / n_cols
+    rstd_1 = rsqrt(mean_square_1 + eps)
+    norm_x1 = X_pow1 * rstd_1
+
+    # Cache rstd values for backward
+    tl.store(RSTD_ptr + 0, rstd_3)
+    tl.store(RSTD_ptr + 1, rstd_2)
+    tl.store(RSTD_ptr + 2, rstd_1)
+
+    # Compute output: y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+    Y_row = w0 * norm_x3 + w1 * norm_x2 + w2 * norm_x1 + b
+
+    # Store output
+    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
+
+
+@triton.jit
+def _poly_norm_backward_kernel(
+    dY_ptr,
+    dY_row_stride,
+    dX_ptr,
+    dX_row_stride,
+    X_ptr,
+    X_row_stride,
+    W_ptr,
+    RSTD_ptr,
+    RSTD_row_stride,
+    dW_ptr,  # shape: (n_programs, 3)
+    dW_row_stride,
+    dB_ptr,  # shape: (n_programs,)
+    n_rows,
+    n_cols,
+    rows_per_program: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    PolyNorm Backward Kernel Gradient:
+        ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
+
+    where:
+    - D_p = RMS(x^p) = 1/rstd_p
+    - S_p = sum(grad * x^p) over the row
+    - d = n_cols
+    - p ∈ {3, 2, 1}
+    """
+    row_block_id = tl.program_id(0).to(tl.int64)
+    row_start = row_block_id * rows_per_program
+    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = col_offsets < n_cols
+
+    # Initialize accumulators for weight and bias gradients (scalars)
+    dW0_acc = 0.0
+    dW1_acc = 0.0
+    dW2_acc = 0.0
+    dB_acc = 0.0
+
+    # Load weights
+    w0 = tl.load(W_ptr + 0).to(tl.float32)
+    w1 = tl.load(W_ptr + 1).to(tl.float32)
+    w2 = tl.load(W_ptr + 2).to(tl.float32)
+
+    dY_ptr += row_start * dY_row_stride
+    dX_ptr += row_start * dX_row_stride
+    X_ptr += row_start * X_row_stride
+    RSTD_ptr += row_start * RSTD_row_stride
+
+    for _ in range(row_start, row_end):
+        # Load input and gradient
+        dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
+        X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
+
+        # Load cached rstd values
+        rstd_3 = tl.load(RSTD_ptr + 0).to(tl.float32)
+        rstd_2 = tl.load(RSTD_ptr + 1).to(tl.float32)
+        rstd_1 = tl.load(RSTD_ptr + 2).to(tl.float32)
+
+        # Compute powers
+        X_pow3 = X_row * X_row * X_row
+        X_pow2 = X_row * X_row
+        X_pow1 = X_row
+
+        # Accumulate bias gradient: dB = sum(dY)
+        dB_acc += tl.sum(dY_row, axis=0)
+
+        # Compute gradient w.r.t. input using closed-form formula
+        # For p=3: ∂L/∂x from w0 * norm(x³)
+        S_3 = tl.sum(dY_row * X_pow3, axis=0)  # scalar
+        grad_x_3 = w0 * (
+            3.0 * X_pow2 * rstd_3 * dY_row
+            - (3.0 / n_cols) * X_row * X_row * X_row * X_row * X_row * (rstd_3 * rstd_3 * rstd_3) * S_3
+        )
+
+        # For p=2: ∂L/∂x from w1 * norm(x²)
+        S_2 = tl.sum(dY_row * X_pow2, axis=0)  # scalar
+        grad_x_2 = w1 * (
+            2.0 * X_row * rstd_2 * dY_row - (2.0 / n_cols) * X_row * X_row * X_row * (rstd_2 * rstd_2 * rstd_2) * S_2
+        )
+
+        # For p=1: ∂L/∂x from w2 * norm(x)
+        S_1 = tl.sum(dY_row * X_pow1, axis=0)  # scalar
+        grad_x_1 = w2 * (1.0 * rstd_1 * dY_row - (1.0 / n_cols) * X_row * (rstd_1 * rstd_1 * rstd_1) * S_1)
+
+        # Accumulate weight gradients using closed-form: dW_p = rstd_p * S_p
+        dW0_acc += rstd_3 * S_3
+        dW1_acc += rstd_2 * S_2
+        dW2_acc += rstd_1 * S_1
+
+        # Total gradient
+        dX_row = grad_x_3 + grad_x_2 + grad_x_1
+
+        # Store gradient
+        tl.store(dX_ptr + col_offsets, dX_row, mask=mask)
+
+        # Update pointers
+        dY_ptr += dY_row_stride
+        dX_ptr += dX_row_stride
+        X_ptr += X_row_stride
+        RSTD_ptr += RSTD_row_stride
+
+    # Store accumulated gradients (scalars)
+    tl.store(dW_ptr + row_block_id * dW_row_stride + 0, dW0_acc)
+    tl.store(dW_ptr + row_block_id * dW_row_stride + 1, dW1_acc)
+    tl.store(dW_ptr + row_block_id * dW_row_stride + 2, dW2_acc)
+    tl.store(dB_ptr + row_block_id, dB_acc)
+
+
+def poly_norm_forward(X, W, B, eps=1e-6):
+    """
+    PolyNorm Forward Pass
+
+    Args:
+        X: input tensor of shape (*, H) where H is hidden dimension
+        W: weight tensor of shape (3,) for [w0, w1, w2]
+        B: bias scalar tensor
+        eps: epsilon for numerical stability
+
+    Returns:
+        Y: output tensor of same shape as X
+        X: reshaped input (for backward)
+        RSTD: cached rstd values (for backward)
+        BLOCK_SIZE: block size used
+        num_warps: number of warps used
+    """
+    shape = X.shape
+    dim = shape[-1]
+    X = X.view(-1, dim)
+    n_rows, n_cols = X.shape
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    # RSTD is to cache rstd for each row
+    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+    RSTD = torch.empty((n_rows, 3), dtype=torch.float32, device=X.device)
+
+    # Check constraints
+    assert W.shape[0] == 3, "Weight tensor must have shape (3,)"
+    assert B.numel() == 1, "Bias must be a scalar"
+
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
+
+    # Launch kernel
+    _poly_norm_forward_kernel[(n_rows,)](
+        Y,
+        Y.stride(0),
+        X,
+        X.stride(0),
+        W,
+        B,
+        RSTD,
+        RSTD.stride(0),
+        n_cols,
+        eps,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+        **kernel_args,
+    )
+
+    return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps
+
+
+def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
+    """
+    PolyNorm Backward Pass
+
+    Args:
+        dY: gradient of output
+        X: input tensor (already reshaped to 2D)
+        W: weight tensor
+        RSTD: cached rstd values from forward
+        BLOCK_SIZE: block size from forward
+        num_warps: number of warps from forward
+        in_place: whether to in-place modify dY to store dX (saves memory)
+
+    Returns:
+        dX: gradient w.r.t. input
+        dW: gradient w.r.t. weight
+        dB: gradient w.r.t. bias
+    """
+    shape = dY.shape
+    dim = shape[-1]
+    dY = dY.view(-1, dim)
+    n_rows, n_cols = dY.shape
+
+    # Get number of SMs for parallelization
+    import math
+
+    sm_count = 1
+    if X.device.type == "cuda":
+        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+    elif X.device.type == "xpu":
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+
+    # Allocate or reuse gradients
+    if in_place is True:
+        dX = dY
+    else:
+        dX = torch.zeros_like(dY)
+
+    _dW = torch.empty((sm_count, 3), dtype=torch.float32, device=W.device)
+    _dB = torch.empty((sm_count,), dtype=torch.float32, device=W.device)
+
+    rows_per_program = math.ceil(n_rows / sm_count)
+    grid = (sm_count,)
+
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
+
+    # Launch backward kernel
+    _poly_norm_backward_kernel[grid](
+        dY,
+        dY.stride(0),
+        dX,
+        dX.stride(0),
+        X,
+        X.stride(0),
+        W,
+        RSTD,
+        RSTD.stride(0),
+        _dW,
+        _dW.stride(0),
+        _dB,
+        n_rows,
+        n_cols,
+        rows_per_program,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+        **kernel_args,
+    )
+
+    # Reduce gradients across SMs
+    dX = dX.view(*shape)
+    dW = _dW.sum(dim=0).to(W.dtype)
+    dB = _dB.sum().to(W.dtype)
+
+    return dX, dW, dB
+
+
+class LigerPolyNormFunction(torch.autograd.Function):
+    """
+    PolyNorm Function with forward and backward pass
+
+    PolyNorm formula:
+        y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+    where norm(u) = u / sqrt(mean(u²) + ε)
+
+    Backward uses closed-form gradient:
+        ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
+    """
+
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, X, W, B, eps=1e-6, in_place=True):
+        """
+        Args:
+            X: input tensor of shape (B, T, H) or (BxT, H)
+            W: weight tensor of shape (3,) for [w0, w1, w2]
+            B: bias scalar
+            eps: epsilon for numerical stability
+            in_place: whether to in-place modify grad_output in backward (saves memory)
+
+        Returns:
+            Y: output tensor of same shape as X
+        """
+        Y, X, RSTD, BLOCK_SIZE, num_warps = poly_norm_forward(X, W, B, eps)
+        ctx.BLOCK_SIZE = BLOCK_SIZE
+        ctx.num_warps = num_warps
+        ctx.in_place = in_place
+        ctx.save_for_backward(X, W, RSTD)
+        return Y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_output):
+        """
+        Args:
+            grad_output: gradient of output
+
+        Returns:
+            dX, dW, dB: gradients w.r.t. X, W, B
+        """
+        X, W, RSTD = ctx.saved_tensors
+        dX, dW, dB = poly_norm_backward(grad_output, X, W, RSTD, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place)
+        return dX, dW, dB, None, None
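For readers of the new kernel, the docstring formula y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b with norm(u) = u / sqrt(mean(u²) + ε) collapses to a few lines of eager PyTorch. The sketch below is a reference for checking LigerPolyNormFunction numerically, not part of the package; shapes and tolerances are arbitrary choices:

import torch

def poly_norm_reference(x, w, b, eps=1e-6):
    # norm(u) = u * rsqrt(mean(u**2) + eps), taken over the last dimension
    def norm(u):
        return u * torch.rsqrt(u.pow(2).mean(dim=-1, keepdim=True) + eps)
    return w[0] * norm(x**3) + w[1] * norm(x**2) + w[2] * norm(x) + b

x = torch.randn(2, 16, 64, requires_grad=True)
w = torch.randn(3, requires_grad=True)
b = torch.randn((), requires_grad=True)
y_ref = poly_norm_reference(x, w, b)

# On a CUDA or XPU device with 0.6.4 installed, the Triton path can be compared against the reference:
#   from liger_kernel.ops.poly_norm import LigerPolyNormFunction
#   y = LigerPolyNormFunction.apply(x.cuda(), w.cuda(), b.cuda(), 1e-6, False)  # in_place=False
#   torch.testing.assert_close(y, y_ref.cuda(), rtol=1e-4, atol=1e-4)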