liger-kernel 0.6.1__py3-none-any.whl → 0.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. liger_kernel/chunked_loss/dpo_loss.py +54 -3
  2. liger_kernel/chunked_loss/fused_linear_ppo.py +4 -0
  3. liger_kernel/chunked_loss/grpo_loss.py +38 -4
  4. liger_kernel/chunked_loss/jsd_loss.py +5 -2
  5. liger_kernel/ops/cross_entropy.py +59 -53
  6. liger_kernel/ops/fused_linear_cross_entropy.py +83 -17
  7. liger_kernel/ops/layer_norm.py +4 -6
  8. liger_kernel/ops/llama4_rope.py +225 -0
  9. liger_kernel/ops/poly_norm.py +386 -0
  10. liger_kernel/transformers/__init__.py +32 -0
  11. liger_kernel/transformers/experimental/__init__.py +5 -0
  12. liger_kernel/transformers/functional.py +9 -0
  13. liger_kernel/transformers/fused_linear_cross_entropy.py +8 -1
  14. liger_kernel/transformers/llama4_rope.py +93 -0
  15. liger_kernel/transformers/model/falcon_h1.py +108 -0
  16. liger_kernel/transformers/model/gemma.py +2 -1
  17. liger_kernel/transformers/model/gemma2.py +8 -2
  18. liger_kernel/transformers/model/gemma3.py +27 -2
  19. liger_kernel/transformers/model/glm4.py +2 -1
  20. liger_kernel/transformers/model/glm4v.py +151 -0
  21. liger_kernel/transformers/model/glm4v_moe.py +153 -0
  22. liger_kernel/transformers/model/internvl.py +150 -0
  23. liger_kernel/transformers/model/llama.py +2 -1
  24. liger_kernel/transformers/model/llama4.py +2 -1
  25. liger_kernel/transformers/model/llava.py +6 -2
  26. liger_kernel/transformers/model/loss_utils.py +3 -0
  27. liger_kernel/transformers/model/mistral.py +2 -1
  28. liger_kernel/transformers/model/mixtral.py +8 -2
  29. liger_kernel/transformers/model/mllama.py +6 -3
  30. liger_kernel/transformers/model/olmo2.py +2 -1
  31. liger_kernel/transformers/model/paligemma.py +19 -0
  32. liger_kernel/transformers/model/phi3.py +10 -160
  33. liger_kernel/transformers/model/qwen2.py +2 -1
  34. liger_kernel/transformers/model/qwen2_5_vl.py +7 -2
  35. liger_kernel/transformers/model/qwen2_vl.py +7 -2
  36. liger_kernel/transformers/model/qwen3.py +2 -1
  37. liger_kernel/transformers/model/qwen3_moe.py +8 -2
  38. liger_kernel/transformers/model/qwen3_next.py +134 -0
  39. liger_kernel/transformers/model/smollm3.py +2 -1
  40. liger_kernel/transformers/model/smolvlm.py +158 -0
  41. liger_kernel/transformers/monkey_patch.py +552 -23
  42. liger_kernel/transformers/multi_token_attention.py +1 -1
  43. liger_kernel/transformers/poly_norm.py +42 -0
  44. liger_kernel/transformers/rms_norm.py +7 -0
  45. {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/METADATA +14 -11
  46. {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/RECORD +50 -39
  47. {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/WHEEL +0 -0
  48. {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/licenses/LICENSE +0 -0
  49. {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/licenses/NOTICE +0 -0
  50. {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/top_level.txt +0 -0
liger_kernel/ops/llama4_rope.py
@@ -0,0 +1,225 @@
+import torch
+import triton
+import triton.language as tl
+
+
+def _prepare_freqs(freqs_cis: torch.Tensor, seq_len: int, head_dim_half: int):
+    # Split or unpack complex frequencies into real and imag parts
+    if freqs_cis.is_complex():
+        freqs_real = freqs_cis.real
+        freqs_imag = freqs_cis.imag
+    else:
+        # Already split: last dim should be 2*head_dim_half
+        if freqs_cis.shape[-1] == 2 * head_dim_half:
+            freqs_real = freqs_cis[..., :head_dim_half]
+            freqs_imag = freqs_cis[..., head_dim_half:]
+        else:
+            raise ValueError(
+                f"Unexpected freqs_cis shape for non-complex input: {freqs_cis.shape}, expected last dim = {2 * head_dim_half}"
+            )
+
+    # Canonicalize to shape (seq_len, head_dim_half):
+    # 1) Ensure the last dimension is head_dim_half
+    if freqs_real.shape[-1] != head_dim_half:
+        raise ValueError(f"Unexpected last dim for freqs: {freqs_real.shape[-1]} (expected {head_dim_half})")
+    # 2) Flatten all leading dims to a single row dimension
+    freqs_real = freqs_real.reshape(-1, head_dim_half)
+    freqs_imag = freqs_imag.reshape(-1, head_dim_half)
+    # 3) If we have fewer rows than seq_len, allow broadcasting when single row
+    if freqs_real.shape[0] < seq_len:
+        if freqs_real.shape[0] == 1:
+            freqs_real = freqs_real.expand(seq_len, -1)
+            freqs_imag = freqs_imag.expand(seq_len, -1)
+        else:
+            raise ValueError(f"Insufficient rows in freqs: {freqs_real.shape[0]} < seq_len={seq_len}")
+    # 4) If we have more rows than seq_len (e.g., batch present), take the first seq_len rows
+    elif freqs_real.shape[0] > seq_len:
+        freqs_real = freqs_real[:seq_len]
+        freqs_imag = freqs_imag[:seq_len]
+
+    return freqs_real, freqs_imag
+
+
+def _maybe_to_dtype(t: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    return t if t.dtype == dtype else t.to(dtype)
+
+
+def _maybe_contiguous(t: torch.Tensor) -> torch.Tensor:
+    return t if t.is_contiguous() else t.contiguous()
+
+
+def _cast_and_contiguous(q, k, freqs_real, freqs_imag):
+    # Choose compute dtype: use fp32 only when inputs are fp32; otherwise keep input dtype for performance
+    compute_dtype = torch.float32 if q.dtype == torch.float32 else q.dtype
+
+    # Make sure q/k share the same dtype before casting to compute dtype
+    if k.dtype != q.dtype:
+        k = k.to(q.dtype)
+
+    q = _maybe_contiguous(_maybe_to_dtype(q, compute_dtype))
+    k = _maybe_contiguous(_maybe_to_dtype(k, compute_dtype))
+    freqs_real = _maybe_contiguous(_maybe_to_dtype(freqs_real, compute_dtype))
+    freqs_imag = _maybe_contiguous(_maybe_to_dtype(freqs_imag, compute_dtype))
+    return q, k, freqs_real, freqs_imag
+
+
+@triton.jit
+def _llama4_rope_kernel(
+    q_ptr,
+    k_ptr,
+    freqs_real_ptr,
+    freqs_imag_ptr,
+    q_row_stride,
+    k_row_stride,
+    q_head_stride,
+    k_head_stride,
+    freqs_row_stride,
+    seq_len,
+    batch_size,
+    imag_sign,
+    head_dim_half: tl.constexpr,
+    n_q_heads: tl.constexpr,
+    n_k_heads: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    H100-optimized RoPE kernel with improved parallelization across heads and dimensions.
+    Grid: (batch*seq, head)
+    """
+    # 2D grid
+    pid_bs = tl.program_id(0)  # over batch*seq
+    pid_h = tl.program_id(1)  # over heads
+
+    batch_idx = pid_bs // seq_len
+    seq_idx = pid_bs % seq_len
+
+    # Bounds check
+    if batch_idx >= batch_size or seq_idx >= seq_len:
+        return
+
+    # Base pointers for this (batch, seq) position
+    base_offset = batch_idx * seq_len + seq_idx
+    q_base = q_ptr + base_offset * q_row_stride
+    k_base = k_ptr + base_offset * k_row_stride
+
+    # Tiling over dim/2
+    for d_start in tl.static_range(0, head_dim_half, BLOCK_SIZE):
+        d_indices = d_start + tl.arange(0, BLOCK_SIZE)
+        mask_d = d_indices < head_dim_half
+
+        # Load frequencies once per tile (freqs layout: [seq_len, head_dim_half])
+        freq_idx = d_indices
+        freqs_real = tl.load(freqs_real_ptr + seq_idx * freqs_row_stride + freq_idx, mask=mask_d, other=0.0)
+        freqs_imag = tl.load(freqs_imag_ptr + seq_idx * freqs_row_stride + freq_idx, mask=mask_d, other=0.0)
+        freqs_imag = freqs_imag * imag_sign
+
+        # Process one query head per program in pid_h
+        if pid_h < n_q_heads:
+            q_head_ptr = q_base + pid_h * q_head_stride
+            q_real = tl.load(q_head_ptr + d_indices * 2, mask=mask_d, other=0.0)
+            q_imag = tl.load(q_head_ptr + d_indices * 2 + 1, mask=mask_d, other=0.0)
+
+            # Complex multiply with FMAs: (a+ib)*(c+id) = (a*c - b*d) + i(a*d + b*c)
+            new_q_real = tl.math.fma(q_real, freqs_real, -(q_imag * freqs_imag))
+            new_q_imag = tl.math.fma(q_real, freqs_imag, q_imag * freqs_real)
+
+            tl.store(q_head_ptr + d_indices * 2, new_q_real, mask=mask_d)
+            tl.store(q_head_ptr + d_indices * 2 + 1, new_q_imag, mask=mask_d)
+
+        # Process one key head per program in pid_h
+        if pid_h < n_k_heads:
+            k_head_ptr = k_base + pid_h * k_head_stride
+            k_real = tl.load(k_head_ptr + d_indices * 2, mask=mask_d, other=0.0)
+            k_imag = tl.load(k_head_ptr + d_indices * 2 + 1, mask=mask_d, other=0.0)
+
+            new_k_real = tl.math.fma(k_real, freqs_real, -(k_imag * freqs_imag))
+            new_k_imag = tl.math.fma(k_real, freqs_imag, k_imag * freqs_real)
+
+            tl.store(k_head_ptr + d_indices * 2, new_k_real, mask=mask_d)
+            tl.store(k_head_ptr + d_indices * 2 + 1, new_k_imag, mask=mask_d)
+
+
+def _select_kernel_meta(head_dim_half: int):
+    # Heuristic tuning for block size and num_warps
+    if head_dim_half >= 256:
+        return 128, 8
+    if head_dim_half >= 96:
+        return 128, 4
+    if head_dim_half >= 48:
+        return 64, 4
+    if head_dim_half >= 24:
+        return 32, 2
+    return 16, 2
+
+
+def llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE: int = None, imag_sign: float = 1.0):
+    # Save original dtype for casting back
+    original_dtype = q.dtype
+
+    batch_size, seq_len, n_q_heads, head_dim = q.shape
+    _, _, n_k_heads, _ = k.shape
+    head_dim_half = head_dim // 2
+
+    # Prepare frequencies
+    freqs_real, freqs_imag = _prepare_freqs(freqs_cis, seq_len, head_dim_half)
+
+    # Cast to appropriate dtype and make contiguous only when needed
+    q, k, freqs_real, freqs_imag = _cast_and_contiguous(q, k, freqs_real, freqs_imag)
+
+    # H100-optimized meta-params
+    if BLOCK_SIZE is None:
+        BLOCK_SIZE, num_warps = _select_kernel_meta(head_dim_half)
+    else:
+        # Provide a default num_warps if caller pins BLOCK_SIZE
+        _, num_warps = _select_kernel_meta(head_dim_half)
+
+    # 2D grid: one program per (batch, seq, head)
+    n_heads_max = max(n_q_heads, n_k_heads)
+    grid = (batch_size * seq_len, n_heads_max)
+
+    # Launch kernel
+    _llama4_rope_kernel[grid](
+        q,
+        k,
+        freqs_real,
+        freqs_imag,
+        q.stride(1),
+        k.stride(1),
+        q.stride(2),
+        k.stride(2),
+        freqs_real.stride(0),
+        seq_len,
+        batch_size,
+        imag_sign,
+        head_dim_half,
+        n_q_heads,
+        n_k_heads,
+        BLOCK_SIZE,
+        num_warps=num_warps,
+        num_stages=2,
+    )
+
+    # Cast back to original dtype only if it differs from compute dtype
+    if q.dtype != original_dtype:
+        q = q.to(original_dtype)
+    if k.dtype != original_dtype:
+        k = k.to(original_dtype)
+
+    return q, k
+
+
+class LigerLlama4RopeFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, q, k, freqs_cis, BLOCK_SIZE: int = None):
+        q_out, k_out = llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE, imag_sign=1.0)
+        ctx.save_for_backward(freqs_cis.detach() if isinstance(freqs_cis, torch.Tensor) else freqs_cis)
+        ctx.BLOCK_SIZE = BLOCK_SIZE
+        return q_out, k_out
+
+    @staticmethod
+    def backward(ctx, dq, dk):
+        (freqs_cis,) = ctx.saved_tensors
+        BLOCK_SIZE = getattr(ctx, "BLOCK_SIZE", None)
+        # Use imag_sign=-1.0 for conjugate without materializing a new tensor
+        dq_out, dk_out = llama4_rope_forward(dq, dk, freqs_cis, BLOCK_SIZE, imag_sign=-1.0)
+        return dq_out, dk_out, None
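The new kernel above rotates interleaved (real, imag) feature pairs of `q` and `k` by the complex frequencies, one `(batch, seq, head)` program at a time, so it can be sanity-checked against plain `torch.view_as_complex` arithmetic. The sketch below is not part of the package: the shapes, CUDA device, and tolerances are illustrative assumptions, and it clones the inputs first because `_cast_and_contiguous` returns the original buffers for contiguous fp32 tensors, which the kernel then rotates in place.

```python
import torch

from liger_kernel.ops.llama4_rope import LigerLlama4RopeFunction

# Illustrative shapes: llama4_rope_forward expects (batch, seq, n_heads, head_dim)
B, S, Hq, Hk, D = 2, 16, 8, 4, 64
q = torch.randn(B, S, Hq, D, device="cuda")
k = torch.randn(B, S, Hk, D, device="cuda")
# One complex frequency per (position, feature pair): shape (S, D // 2)
freqs_cis = torch.polar(torch.ones(S, D // 2, device="cuda"), torch.randn(S, D // 2, device="cuda"))

# Contiguous fp32 inputs are rotated in place, so keep copies for the reference
q_in, k_in = q.clone(), k.clone()
q_out, k_out = LigerLlama4RopeFunction.apply(q, k, freqs_cis)


def rope_ref(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # Interpret consecutive (even, odd) features as complex pairs, matching the kernel's 2*d / 2*d+1 layout
    xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
    rotated = xc * freqs.view(1, S, 1, -1)  # broadcast over batch and heads
    return torch.view_as_real(rotated).flatten(-2)


torch.testing.assert_close(q_out, rope_ref(q_in, freqs_cis), rtol=1e-4, atol=1e-4)
torch.testing.assert_close(k_out, rope_ref(k_in, freqs_cis), rtol=1e-4, atol=1e-4)
```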
liger_kernel/ops/poly_norm.py
@@ -0,0 +1,386 @@
+import operator
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        from triton.language.extra.libdevice import rsqrt
+    except ModuleNotFoundError:
+        from triton.language.extra.cuda.libdevice import rsqrt
+else:
+    from triton.language.math import rsqrt
+
+
+@triton.jit
+def _poly_norm_forward_kernel(
+    Y_ptr,
+    Y_row_stride,
+    X_ptr,
+    X_row_stride,
+    W_ptr,  # weight: [3] for [w0, w1, w2]
+    B_ptr,  # bias: scalar
+    RSTD_ptr,  # cache rstd for backward: shape (n_rows, 3)
+    RSTD_row_stride,
+    n_cols,
+    eps,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    PolyNorm formula:
+    y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+    where norm(u) = u / sqrt(mean(u²) + ε)
+
+    Reference:
+    1. https://github.com/BryceZhuo/PolyCom/
+    2. https://arxiv.org/pdf/2411.03884
+
+    Cache rstd values for backward pass
+    """
+    row_idx = tl.program_id(0).to(tl.int64)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = col_offsets < n_cols
+
+    # Load pointers
+    Y_ptr += row_idx * Y_row_stride
+    X_ptr += row_idx * X_row_stride
+    RSTD_ptr += row_idx * RSTD_row_stride
+
+    # Load input row
+    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
+
+    # Load weights and bias
+    w0 = tl.load(W_ptr + 0)
+    w1 = tl.load(W_ptr + 1)
+    w2 = tl.load(W_ptr + 2)
+    b = tl.load(B_ptr)
+
+    # Compute x³, x², x
+    X_pow3 = X_row * X_row * X_row
+    X_pow2 = X_row * X_row
+    X_pow1 = X_row
+
+    # Compute norm(x³): norm(u) = u * rsqrt(mean(u²) + eps)
+    mean_square_3 = tl.sum(X_pow3 * X_pow3, axis=0) / n_cols
+    rstd_3 = rsqrt(mean_square_3 + eps)
+    norm_x3 = X_pow3 * rstd_3
+
+    # Compute norm(x²)
+    mean_square_2 = tl.sum(X_pow2 * X_pow2, axis=0) / n_cols
+    rstd_2 = rsqrt(mean_square_2 + eps)
+    norm_x2 = X_pow2 * rstd_2
+
+    # Compute norm(x)
+    mean_square_1 = tl.sum(X_pow1 * X_pow1, axis=0) / n_cols
+    rstd_1 = rsqrt(mean_square_1 + eps)
+    norm_x1 = X_pow1 * rstd_1
+
+    # Cache rstd values for backward
+    tl.store(RSTD_ptr + 0, rstd_3)
+    tl.store(RSTD_ptr + 1, rstd_2)
+    tl.store(RSTD_ptr + 2, rstd_1)
+
+    # Compute output: y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+    Y_row = w0 * norm_x3 + w1 * norm_x2 + w2 * norm_x1 + b
+
+    # Store output
+    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
+
+
+@triton.jit
+def _poly_norm_backward_kernel(
+    dY_ptr,
+    dY_row_stride,
+    dX_ptr,
+    dX_row_stride,
+    X_ptr,
+    X_row_stride,
+    W_ptr,
+    RSTD_ptr,
+    RSTD_row_stride,
+    dW_ptr,  # shape: (n_programs, 3)
+    dW_row_stride,
+    dB_ptr,  # shape: (n_programs,)
+    n_rows,
+    n_cols,
+    rows_per_program: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    PolyNorm Backward Kernel Gradient:
+    ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
+
+    where:
+    - D_p = RMS(x^p) = 1/rstd_p
+    - S_p = sum(grad * x^p) over the row
+    - d = n_cols
+    - p ∈ {3, 2, 1}
+    """
+    row_block_id = tl.program_id(0).to(tl.int64)
+    row_start = row_block_id * rows_per_program
+    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = col_offsets < n_cols
+
+    # Initialize accumulators for weight and bias gradients (scalars)
+    dW0_acc = 0.0
+    dW1_acc = 0.0
+    dW2_acc = 0.0
+    dB_acc = 0.0
+
+    # Load weights
+    w0 = tl.load(W_ptr + 0).to(tl.float32)
+    w1 = tl.load(W_ptr + 1).to(tl.float32)
+    w2 = tl.load(W_ptr + 2).to(tl.float32)
+
+    dY_ptr += row_start * dY_row_stride
+    dX_ptr += row_start * dX_row_stride
+    X_ptr += row_start * X_row_stride
+    RSTD_ptr += row_start * RSTD_row_stride
+
+    for _ in range(row_start, row_end):
+        # Load input and gradient
+        dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
+        X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
+
+        # Load cached rstd values
+        rstd_3 = tl.load(RSTD_ptr + 0).to(tl.float32)
+        rstd_2 = tl.load(RSTD_ptr + 1).to(tl.float32)
+        rstd_1 = tl.load(RSTD_ptr + 2).to(tl.float32)
+
+        # Compute powers
+        X_pow3 = X_row * X_row * X_row
+        X_pow2 = X_row * X_row
+        X_pow1 = X_row
+
+        # Accumulate bias gradient: dB = sum(dY)
+        dB_acc += tl.sum(dY_row, axis=0)
+
+        # Compute gradient w.r.t. input using closed-form formula
+        # For p=3: ∂L/∂x from w0 * norm(x³)
+        S_3 = tl.sum(dY_row * X_pow3, axis=0)  # scalar
+        grad_x_3 = w0 * (
+            3.0 * X_pow2 * rstd_3 * dY_row
+            - (3.0 / n_cols) * X_row * X_row * X_row * X_row * X_row * (rstd_3 * rstd_3 * rstd_3) * S_3
+        )
+
+        # For p=2: ∂L/∂x from w1 * norm(x²)
+        S_2 = tl.sum(dY_row * X_pow2, axis=0)  # scalar
+        grad_x_2 = w1 * (
+            2.0 * X_row * rstd_2 * dY_row - (2.0 / n_cols) * X_row * X_row * X_row * (rstd_2 * rstd_2 * rstd_2) * S_2
+        )
+
+        # For p=1: ∂L/∂x from w2 * norm(x)
+        S_1 = tl.sum(dY_row * X_pow1, axis=0)  # scalar
+        grad_x_1 = w2 * (1.0 * rstd_1 * dY_row - (1.0 / n_cols) * X_row * (rstd_1 * rstd_1 * rstd_1) * S_1)
+
+        # Accumulate weight gradients using closed-form: dW_p = rstd_p * S_p
+        dW0_acc += rstd_3 * S_3
+        dW1_acc += rstd_2 * S_2
+        dW2_acc += rstd_1 * S_1
+
+        # Total gradient
+        dX_row = grad_x_3 + grad_x_2 + grad_x_1
+
+        # Store gradient
+        tl.store(dX_ptr + col_offsets, dX_row, mask=mask)
+
+        # Update pointers
+        dY_ptr += dY_row_stride
+        dX_ptr += dX_row_stride
+        X_ptr += X_row_stride
+        RSTD_ptr += RSTD_row_stride
+
+    # Store accumulated gradients (scalars)
+    tl.store(dW_ptr + row_block_id * dW_row_stride + 0, dW0_acc)
+    tl.store(dW_ptr + row_block_id * dW_row_stride + 1, dW1_acc)
+    tl.store(dW_ptr + row_block_id * dW_row_stride + 2, dW2_acc)
+    tl.store(dB_ptr + row_block_id, dB_acc)
+
+
+def poly_norm_forward(X, W, B, eps=1e-6):
+    """
+    PolyNorm Forward Pass
+
+    Args:
+        X: input tensor of shape (*, H) where H is hidden dimension
+        W: weight tensor of shape (3,) for [w0, w1, w2]
+        B: bias scalar tensor
+        eps: epsilon for numerical stability
+
+    Returns:
+        Y: output tensor of same shape as X
+        X: reshaped input (for backward)
+        RSTD: cached rstd values (for backward)
+        BLOCK_SIZE: block size used
+        num_warps: number of warps used
+    """
+    shape = X.shape
+    dim = shape[-1]
+    X = X.view(-1, dim)
+    n_rows, n_cols = X.shape
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    # RSTD is to cache rstd for each row
+    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+    RSTD = torch.empty((n_rows, 3), dtype=torch.float32, device=X.device)
+
+    # Check constraints
+    assert W.shape[0] == 3, "Weight tensor must have shape (3,)"
+    assert B.numel() == 1, "Bias must be a scalar"
+
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
+
+    # Launch kernel
+    _poly_norm_forward_kernel[(n_rows,)](
+        Y,
+        Y.stride(0),
+        X,
+        X.stride(0),
+        W,
+        B,
+        RSTD,
+        RSTD.stride(0),
+        n_cols,
+        eps,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+        **kernel_args,
+    )
+
+    return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps
+
+
+def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
+    """
+    PolyNorm Backward Pass
+
+    Args:
+        dY: gradient of output
+        X: input tensor (already reshaped to 2D)
+        W: weight tensor
+        RSTD: cached rstd values from forward
+        BLOCK_SIZE: block size from forward
+        num_warps: number of warps from forward
+        in_place: whether to in-place modify dY to store dX (saves memory)
+
+    Returns:
+        dX: gradient w.r.t. input
+        dW: gradient w.r.t. weight
+        dB: gradient w.r.t. bias
+    """
+    shape = dY.shape
+    dim = shape[-1]
+    dY = dY.view(-1, dim)
+    n_rows, n_cols = dY.shape
+
+    # Get number of SMs for parallelization
+    import math
+
+    sm_count = 1
+    if X.device.type == "cuda":
+        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+    elif X.device.type == "xpu":
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+
+    # Allocate or reuse gradients
+    if in_place is True:
+        dX = dY
+    else:
+        dX = torch.zeros_like(dY)
+
+    _dW = torch.empty((sm_count, 3), dtype=torch.float32, device=W.device)
+    _dB = torch.empty((sm_count,), dtype=torch.float32, device=W.device)
+
+    rows_per_program = math.ceil(n_rows / sm_count)
+    grid = (sm_count,)
+
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
+
+    # Launch backward kernel
+    _poly_norm_backward_kernel[grid](
+        dY,
+        dY.stride(0),
+        dX,
+        dX.stride(0),
+        X,
+        X.stride(0),
+        W,
+        RSTD,
+        RSTD.stride(0),
+        _dW,
+        _dW.stride(0),
+        _dB,
+        n_rows,
+        n_cols,
+        rows_per_program,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+        **kernel_args,
+    )
+
+    # Reduce gradients across SMs
+    dX = dX.view(*shape)
+    dW = _dW.sum(dim=0).to(W.dtype)
+    dB = _dB.sum().to(W.dtype)
+
+    return dX, dW, dB
+
+
+class LigerPolyNormFunction(torch.autograd.Function):
+    """
+    PolyNorm Function with forward and backward pass
+
+    PolyNorm formula:
+    y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+    where norm(u) = u / sqrt(mean(u²) + ε)
+
+    Backward uses closed-form gradient:
+    ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
+    """
+
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, X, W, B, eps=1e-6, in_place=True):
+        """
+        Args:
+            X: input tensor of shape (B, T, H) or (BxT, H)
+            W: weight tensor of shape (3,) for [w0, w1, w2]
+            B: bias scalar
+            eps: epsilon for numerical stability
+            in_place: whether to in-place modify grad_output in backward (saves memory)
+
+        Returns:
+            Y: output tensor of same shape as X
+        """
+        Y, X, RSTD, BLOCK_SIZE, num_warps = poly_norm_forward(X, W, B, eps)
+        ctx.BLOCK_SIZE = BLOCK_SIZE
+        ctx.num_warps = num_warps
+        ctx.in_place = in_place
+        ctx.save_for_backward(X, W, RSTD)
+        return Y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_output):
+        """
+        Args:
+            grad_output: gradient of output
+
+        Returns:
+            dX, dW, dB: gradients w.r.t. X, W, B
+        """
+        X, W, RSTD = ctx.saved_tensors
+        dX, dW, dB = poly_norm_backward(grad_output, X, W, RSTD, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place)
+        return dX, dW, dB, None, None
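The PolyNorm formula quoted in the docstrings above maps to a few lines of plain PyTorch, which is convenient for checking the fused Triton path numerically. The snippet below is only an illustrative sketch, not part of the package: the shapes, device, and tolerances are assumptions, and `in_place=False` is passed so the backward kernel writes gradients into a fresh buffer instead of reusing `grad_output`.

```python
import torch

from liger_kernel.ops.poly_norm import LigerPolyNormFunction


def poly_norm_ref(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # y = w0 * norm(x^3) + w1 * norm(x^2) + w2 * norm(x) + b, with norm(u) = u / sqrt(mean(u^2) + eps)
    def norm(u: torch.Tensor) -> torch.Tensor:
        return u * torch.rsqrt(u.pow(2).mean(dim=-1, keepdim=True) + eps)

    return w[0] * norm(x**3) + w[1] * norm(x**2) + w[2] * norm(x) + b


x = torch.randn(4, 128, 2048, device="cuda", requires_grad=True)
w = torch.randn(3, device="cuda", requires_grad=True)
b = torch.randn((), device="cuda", requires_grad=True)

# eps and in_place are positional arguments of LigerPolyNormFunction.forward
y = LigerPolyNormFunction.apply(x, w, b, 1e-6, False)
torch.testing.assert_close(y, poly_norm_ref(x, w, b), rtol=1e-4, atol=1e-4)

# Gradients for x, w, and b flow through the Triton backward kernel
y.sum().backward()
```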