liger-kernel-nightly 0.5.5.dev20250402185702__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly has been flagged as potentially problematic; consult the registry's advisory page for details.

Files changed (115)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +61 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +36 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  8. liger_kernel/chunked_loss/grpo_loss.py +76 -5
  9. liger_kernel/chunked_loss/jsd_loss.py +46 -15
  10. liger_kernel/ops/__init__.py +141 -0
  11. liger_kernel/ops/backends/README.md +151 -0
  12. liger_kernel/ops/backends/__init__.py +13 -0
  13. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  14. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  15. liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
  16. liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
  17. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  18. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  19. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  20. liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
  21. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  22. liger_kernel/ops/backends/registry.py +61 -0
  23. liger_kernel/ops/cross_entropy.py +134 -65
  24. liger_kernel/ops/dyt.py +115 -180
  25. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  26. liger_kernel/ops/fused_linear_cross_entropy.py +117 -23
  27. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  28. liger_kernel/ops/geglu.py +6 -4
  29. liger_kernel/ops/group_norm.py +7 -7
  30. liger_kernel/ops/grpo_loss.py +312 -0
  31. liger_kernel/ops/jsd.py +2 -1
  32. liger_kernel/ops/kl_div.py +9 -5
  33. liger_kernel/ops/layer_norm.py +146 -78
  34. liger_kernel/ops/llama4_rope.py +225 -0
  35. liger_kernel/ops/multi_token_attention.py +207 -0
  36. liger_kernel/ops/poly_norm.py +390 -0
  37. liger_kernel/ops/rms_norm.py +398 -99
  38. liger_kernel/ops/rope.py +1 -1
  39. liger_kernel/ops/softmax.py +201 -0
  40. liger_kernel/ops/sparsemax.py +179 -0
  41. liger_kernel/ops/swiglu.py +1 -1
  42. liger_kernel/ops/tiled_mlp.py +136 -0
  43. liger_kernel/ops/utils.py +14 -0
  44. liger_kernel/transformers/__init__.py +208 -17
  45. liger_kernel/transformers/auto_model.py +21 -0
  46. liger_kernel/transformers/cross_entropy.py +9 -4
  47. liger_kernel/transformers/dyt.py +6 -4
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -1
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +122 -20
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -1
  57. liger_kernel/transformers/group_norm.py +1 -1
  58. liger_kernel/transformers/grpo_loss.py +153 -0
  59. liger_kernel/transformers/jsd.py +1 -1
  60. liger_kernel/transformers/kl_div.py +1 -1
  61. liger_kernel/transformers/layer_norm.py +1 -1
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/exaone4.py +136 -0
  64. liger_kernel/transformers/model/falcon_h1.py +122 -0
  65. liger_kernel/transformers/model/gemma.py +57 -27
  66. liger_kernel/transformers/model/gemma2.py +65 -28
  67. liger_kernel/transformers/model/gemma3.py +331 -0
  68. liger_kernel/transformers/model/glm4.py +141 -0
  69. liger_kernel/transformers/model/glm4v.py +163 -0
  70. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  71. liger_kernel/transformers/model/gpt_oss.py +211 -0
  72. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  73. liger_kernel/transformers/model/internvl.py +157 -0
  74. liger_kernel/transformers/model/llama.py +109 -27
  75. liger_kernel/transformers/model/llama4.py +121 -0
  76. liger_kernel/transformers/model/llava.py +111 -136
  77. liger_kernel/transformers/model/loss_utils.py +50 -12
  78. liger_kernel/transformers/model/mistral.py +51 -34
  79. liger_kernel/transformers/model/mixtral.py +50 -29
  80. liger_kernel/transformers/model/mllama.py +46 -24
  81. liger_kernel/transformers/model/olmo2.py +47 -22
  82. liger_kernel/transformers/model/olmo3.py +142 -0
  83. liger_kernel/transformers/model/output_classes.py +147 -0
  84. liger_kernel/transformers/model/paligemma.py +50 -14
  85. liger_kernel/transformers/model/phi3.py +47 -172
  86. liger_kernel/transformers/model/qwen2.py +55 -23
  87. liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
  88. liger_kernel/transformers/model/qwen2_vl.py +59 -108
  89. liger_kernel/transformers/model/qwen3.py +136 -0
  90. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  91. liger_kernel/transformers/model/qwen3_next.py +146 -0
  92. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  93. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  94. liger_kernel/transformers/model/smollm3.py +199 -0
  95. liger_kernel/transformers/model/smolvlm.py +158 -0
  96. liger_kernel/transformers/monkey_patch.py +2018 -244
  97. liger_kernel/transformers/multi_token_attention.py +64 -0
  98. liger_kernel/transformers/poly_norm.py +42 -0
  99. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  100. liger_kernel/transformers/rms_norm.py +54 -6
  101. liger_kernel/transformers/rope.py +45 -1
  102. liger_kernel/transformers/softmax.py +12 -0
  103. liger_kernel/transformers/sparsemax.py +16 -0
  104. liger_kernel/transformers/swiglu.py +39 -1
  105. liger_kernel/transformers/tiled_mlp.py +125 -0
  106. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  107. liger_kernel/transformers/tvd.py +1 -1
  108. liger_kernel/utils.py +63 -0
  109. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +73 -39
  110. liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
  111. liger_kernel_nightly-0.5.5.dev20250402185702.dist-info/RECORD +0 -80
  112. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
  115. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,207 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import triton
4
+ import triton.language as tl
5
+
6
+ from torch.nn.modules.utils import _pair
7
+
8
+ from liger_kernel.ops.softmax import _softmax_forward
9
+ from liger_kernel.ops.sparsemax import _sparsemax_backward
10
+ from liger_kernel.ops.sparsemax import _sparsemax_forward
11
+ from liger_kernel.ops.utils import calculate_settings
12
+ from liger_kernel.ops.utils import ensure_contiguous
13
+
14
+
15
+ @triton.jit
16
+ def _mask_fwd_kernel(
17
+ scores_ptr,
18
+ out_ptr,
19
+ stride_b,
20
+ stride_m,
21
+ stride_n,
22
+ L,
23
+ mask_val: tl.constexpr,
24
+ BLOCK: tl.constexpr,
25
+ num_warps: tl.constexpr,
26
+ ):
27
+ row_block = tl.program_id(0)
28
+ col_block = tl.program_id(1)
29
+ batch_id = tl.program_id(2)
30
+
31
+ row_idx = row_block * BLOCK + tl.arange(0, BLOCK)
32
+ col_idx = col_block * BLOCK + tl.arange(0, BLOCK)
33
+ in_bounds = (row_idx[:, None] < L) & (col_idx[None, :] < L)
34
+
35
+ base = scores_ptr + batch_id * stride_b
36
+ offs = row_idx[:, None] * stride_m + col_idx[None, :] * stride_n
37
+ future = col_idx[None, :] > row_idx[:, None]
38
+ mask_load = in_bounds & ~future
39
+ out = tl.load(base + offs, mask=mask_load, other=mask_val, cache_modifier=".ca")
40
+ tl.store(out_ptr + batch_id * stride_b + offs, out, mask=in_bounds, cache_modifier=".cs")
41
+
42
+
43
+ @triton.jit
44
+ def _mask_bwd_kernel(
45
+ grad_in_ptr, out_ptr, stride_b, stride_m, stride_n, L, BLOCK: tl.constexpr, num_warps: tl.constexpr
46
+ ):
47
+ row_block = tl.program_id(0)
48
+ col_block = tl.program_id(1)
49
+ batch_id = tl.program_id(2)
50
+
51
+ row_idx = row_block * BLOCK + tl.arange(0, BLOCK)
52
+ col_idx = col_block * BLOCK + tl.arange(0, BLOCK)
53
+ in_bounds = (row_idx[:, None] < L) & (col_idx[None, :] < L)
54
+
55
+ base = grad_in_ptr + batch_id * stride_b
56
+ offs = row_idx[:, None] * stride_m + col_idx[None, :] * stride_n
57
+ grad_vals = tl.load(base + offs, mask=in_bounds, other=0.0, cache_modifier=".ca")
58
+
59
+ future = col_idx[None, :] > row_idx[:, None]
60
+ zero = tl.zeros(grad_vals.shape, dtype=grad_vals.dtype)
61
+ out = tl.where(future, zero, grad_vals)
62
+
63
+ tl.store(out_ptr + batch_id * stride_b + offs, out, mask=in_bounds, cache_modifier=".wb")
64
+
65
+
66
def _mask_inf_forward(scores: torch.Tensor) -> torch.Tensor:
    """Return a copy of `scores` with strictly upper-triangular positions set
    to -1e9 (an effective -inf before softmax/sparsemax).

    Args:
        scores: tensor of shape (*batch, L, L); assumed contiguous so the
            flattening `view` below is valid.

    Returns:
        Tensor of the same shape with entries above the diagonal replaced.
    """
    *batch, L, _ = scores.shape
    # Plain Python product of the leading dims; the original
    # int(torch.prod(torch.tensor(batch))) allocated a CPU tensor on every
    # call just to multiply a handful of ints.
    N = 1
    for d in batch:
        N *= d
    scores_f = scores.view(N, L, L)
    out = torch.empty_like(scores_f)

    sb, sm, sn = scores_f.stride(0), scores_f.stride(1), scores_f.stride(2)
    BLOCK_SIZE, num_warps = calculate_settings(L)
    grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
    _mask_fwd_kernel[grid](scores_f, out, sb, sm, sn, L, mask_val=-1e9, BLOCK=BLOCK_SIZE, num_warps=num_warps)
    return out.view(*batch, L, L)
77
+
78
+
79
def _mask_inf_backward(grad: torch.Tensor) -> torch.Tensor:
    """Backward of `_mask_inf_forward`: zero the gradient at strictly
    upper-triangular positions (their forward outputs were constants).

    Args:
        grad: gradient tensor of shape (*batch, L, L); assumed contiguous.

    Returns:
        Gradient tensor of the same shape with future positions zeroed.
    """
    *batch, L, _ = grad.shape
    # Plain Python product of the leading dims; avoids the per-call CPU
    # tensor allocation of torch.prod(torch.tensor(batch)).
    N = 1
    for d in batch:
        N *= d
    grad_f = grad.view(N, L, L)
    out = torch.empty_like(grad_f)

    sb, sm, sn = grad_f.stride(0), grad_f.stride(1), grad_f.stride(2)
    BLOCK_SIZE, num_warps = calculate_settings(L)
    grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
    _mask_bwd_kernel[grid](grad_f, out, sb, sm, sn, L, BLOCK=BLOCK_SIZE, num_warps=num_warps)
    return out.view(*batch, L, L)
90
+
91
+
92
def _mask_zero_forward(scores: torch.Tensor) -> torch.Tensor:
    """Return a copy of `scores` with strictly upper-triangular positions set
    to 0.0 (used after the convolution to re-impose causality).

    Args:
        scores: tensor of shape (*batch, L, L); assumed contiguous.

    Returns:
        Tensor of the same shape with entries above the diagonal zeroed.
    """
    *batch, L, _ = scores.shape
    # Plain Python product of the leading dims; avoids the per-call CPU
    # tensor allocation of torch.prod(torch.tensor(batch)).
    N = 1
    for d in batch:
        N *= d
    scores_f = scores.view(N, L, L)
    out = torch.empty_like(scores_f)

    sb, sm, sn = scores_f.stride(0), scores_f.stride(1), scores_f.stride(2)
    BLOCK_SIZE, num_warps = calculate_settings(L)
    grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
    _mask_fwd_kernel[grid](scores_f, out, sb, sm, sn, L, mask_val=0.0, BLOCK=BLOCK_SIZE, num_warps=num_warps)
    return out.view(*batch, L, L)
103
+
104
+
105
def _mask_zero_backward(grad: torch.Tensor) -> torch.Tensor:
    """Backward of `_mask_zero_forward`: zero the gradient at strictly
    upper-triangular positions. Same kernel as `_mask_inf_backward`, since
    both forwards only overwrite masked positions with a constant.

    Args:
        grad: gradient tensor of shape (*batch, L, L); assumed contiguous.

    Returns:
        Gradient tensor of the same shape with future positions zeroed.
    """
    *batch, L, _ = grad.shape
    # Plain Python product of the leading dims; avoids the per-call CPU
    # tensor allocation of torch.prod(torch.tensor(batch)).
    N = 1
    for d in batch:
        N *= d
    grad_f = grad.view(N, L, L)
    out = torch.empty_like(grad_f)

    sb, sm, sn = grad_f.stride(0), grad_f.stride(1), grad_f.stride(2)
    BLOCK_SIZE, num_warps = calculate_settings(L)
    grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
    _mask_bwd_kernel[grid](grad_f, out, sb, sm, sn, L, BLOCK=BLOCK_SIZE, num_warps=num_warps)
    return out.view(*batch, L, L)
116
+
117
+
118
class LigerMultiTokenAttentionFunction(torch.autograd.Function):
    """Multi-token attention: causal-masked softmax (or sparsemax) over the
    scores, followed by a 2D convolution over the probability map, with the
    causal mask re-applied (as zeros) to the convolution output.

    The backward pass differentiates through all three stages analytically:
    the zero-mask, the convolution (via conv_transpose2d / conv2d_weight),
    and the softmax/sparsemax.
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, scores, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, sparse=False):
        """Forward pass.

        Args:
            scores: attention scores; last two dims are (L, L).
            weight: conv2d weight.
            bias: optional conv2d bias.
            stride, padding, dilation, groups: conv2d parameters.
            sparse: use sparsemax instead of softmax (fp32 scores only).

        Returns:
            Convolved probabilities with future positions zeroed.
        """
        # Mask future positions with -1e9 before normalization.
        scores_inf = _mask_inf_forward(scores)

        out_flat_sparse = None
        activation_output = None

        ctx.sparse = sparse

        # NOTE: the number/order of saved tensors differs between the two
        # branches; ctx.out_flat_sparse_saved tells backward how to unpack.
        if sparse:
            if scores_inf.dtype != torch.float32:
                raise RuntimeError("Liger sparse multi-token attention currently only supports fp32 input scores")
            probs_sparse, out_flat_sparse = _sparsemax_forward(scores_inf, dim=-1)
            activation_output = probs_sparse
            ctx.save_for_backward(scores_inf, activation_output, out_flat_sparse, weight, bias)
            ctx.out_flat_sparse_saved = True
        else:
            probs_softmax, _, _, _ = _softmax_forward(scores_inf)
            activation_output = probs_softmax
            ctx.save_for_backward(scores_inf, activation_output, weight, bias)
            ctx.out_flat_sparse_saved = False

        # Convolve the probability map.
        out_conv = F.conv2d(
            activation_output,
            weight,
            bias,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )

        # Re-impose causality on the convolution output (zeros, not -inf,
        # since no normalization follows).
        out = _mask_zero_forward(out_conv)

        # Normalize conv params to pairs so backward can reuse them directly.
        ctx.stride = _pair(stride)
        ctx.padding = _pair(padding)
        ctx.dilation = _pair(dilation)
        ctx.groups = groups
        ctx.dim = -1

        return out

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_out):
        """Backward pass.

        Returns gradients for (scores, weight, bias) and None for the five
        non-tensor forward arguments.
        """
        # Unpack according to the layout chosen in forward.
        if ctx.out_flat_sparse_saved:
            scores_inf, activation_output, out_flat_sparse, weight, bias = ctx.saved_tensors
        else:
            scores_inf, activation_output, weight, bias = ctx.saved_tensors
            out_flat_sparse = None

        use_sparsemax = ctx.sparse
        dim = ctx.dim
        stride, padding, dilation, groups = (ctx.stride, ctx.padding, ctx.dilation, ctx.groups)

        # Gradient through the output zero-mask.
        grad_conv = _mask_zero_backward(grad_out)

        # Gradient w.r.t. the probability map (transpose of the conv).
        grad_probs = F.conv_transpose2d(
            grad_conv, weight, None, stride=stride, padding=padding, dilation=dilation, groups=groups
        )

        # Gradient w.r.t. the conv weight.
        grad_weight = torch.nn.grad.conv2d_weight(
            input=activation_output,
            weight_size=weight.shape,
            grad_output=grad_conv,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
        grad_bias = None
        if bias is not None:
            # Bias gradient sums over batch and spatial dims.
            grad_bias = grad_conv.sum(dim=(0, 2, 3))

        grad_scores_inf = None
        if use_sparsemax:
            if not ctx.out_flat_sparse_saved or out_flat_sparse is None:
                raise RuntimeError("Internal error: Sparse flag is set but sparse tensor was not saved.")
            grad_scores_inf = _sparsemax_backward(grad_probs, out_flat_sparse, dim=dim)
        else:
            # Softmax Jacobian-vector product: p * (g - <g, p>).
            grad_probs_cont = grad_probs
            probs_cont = activation_output
            dot = (grad_probs_cont * probs_cont).sum(dim=-1, keepdim=True)
            grad_scores_inf = probs_cont * (grad_probs_cont - dot)

        # Gradient through the input -inf mask.
        grad_scores = _mask_inf_backward(grad_scores_inf)

        return (grad_scores, grad_weight, grad_bias, None, None, None, None, None)
@@ -0,0 +1,390 @@
1
+ import operator
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from liger_kernel.ops.utils import calculate_settings
8
+ from liger_kernel.ops.utils import compare_version
9
+ from liger_kernel.ops.utils import ensure_contiguous
10
+ from liger_kernel.utils import get_npu_multi_processor_count
11
+ from liger_kernel.utils import is_npu_available
12
+
13
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
14
+ try:
15
+ from triton.language.extra.libdevice import rsqrt
16
+ except ModuleNotFoundError:
17
+ from triton.language.extra.cuda.libdevice import rsqrt
18
+ else:
19
+ from triton.language.math import rsqrt
20
+
21
+
22
+ @triton.jit
23
+ def _poly_norm_forward_kernel(
24
+ Y_ptr,
25
+ Y_row_stride,
26
+ X_ptr,
27
+ X_row_stride,
28
+ W_ptr, # weight: [3] for [w0, w1, w2]
29
+ B_ptr, # bias: scalar
30
+ RSTD_ptr, # cache rstd for backward: shape (n_rows, 3)
31
+ RSTD_row_stride,
32
+ n_cols,
33
+ eps,
34
+ BLOCK_SIZE: tl.constexpr,
35
+ ):
36
+ """
37
+ PolyNorm formula:
38
+ y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
39
+ where norm(u) = u / sqrt(mean(u²) + ε)
40
+
41
+ Reference:
42
+ 1. https://github.com/BryceZhuo/PolyCom/
43
+ 2. https://arxiv.org/pdf/2411.03884
44
+
45
+ Cache rstd values for backward pass
46
+ """
47
+ row_idx = tl.program_id(0).to(tl.int64)
48
+ col_offsets = tl.arange(0, BLOCK_SIZE)
49
+ mask = col_offsets < n_cols
50
+
51
+ # Load pointers
52
+ Y_ptr += row_idx * Y_row_stride
53
+ X_ptr += row_idx * X_row_stride
54
+ RSTD_ptr += row_idx * RSTD_row_stride
55
+
56
+ # Load input row
57
+ X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
58
+
59
+ # Load weights and bias
60
+ w0 = tl.load(W_ptr + 0)
61
+ w1 = tl.load(W_ptr + 1)
62
+ w2 = tl.load(W_ptr + 2)
63
+ b = tl.load(B_ptr)
64
+
65
+ # Compute x³, x², x
66
+ X_pow3 = X_row * X_row * X_row
67
+ X_pow2 = X_row * X_row
68
+ X_pow1 = X_row
69
+
70
+ # Compute norm(x³): norm(u) = u * rsqrt(mean(u²) + eps)
71
+ mean_square_3 = tl.sum(X_pow3 * X_pow3, axis=0) / n_cols
72
+ rstd_3 = rsqrt(mean_square_3 + eps)
73
+ norm_x3 = X_pow3 * rstd_3
74
+
75
+ # Compute norm(x²)
76
+ mean_square_2 = tl.sum(X_pow2 * X_pow2, axis=0) / n_cols
77
+ rstd_2 = rsqrt(mean_square_2 + eps)
78
+ norm_x2 = X_pow2 * rstd_2
79
+
80
+ # Compute norm(x)
81
+ mean_square_1 = tl.sum(X_pow1 * X_pow1, axis=0) / n_cols
82
+ rstd_1 = rsqrt(mean_square_1 + eps)
83
+ norm_x1 = X_pow1 * rstd_1
84
+
85
+ # Cache rstd values for backward
86
+ tl.store(RSTD_ptr + 0, rstd_3)
87
+ tl.store(RSTD_ptr + 1, rstd_2)
88
+ tl.store(RSTD_ptr + 2, rstd_1)
89
+
90
+ # Compute output: y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
91
+ Y_row = w0 * norm_x3 + w1 * norm_x2 + w2 * norm_x1 + b
92
+
93
+ # Store output
94
+ tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
95
+
96
+
97
+ @triton.jit
98
+ def _poly_norm_backward_kernel(
99
+ dY_ptr,
100
+ dY_row_stride,
101
+ dX_ptr,
102
+ dX_row_stride,
103
+ X_ptr,
104
+ X_row_stride,
105
+ W_ptr,
106
+ RSTD_ptr,
107
+ RSTD_row_stride,
108
+ dW_ptr, # shape: (n_programs, 3)
109
+ dW_row_stride,
110
+ dB_ptr, # shape: (n_programs,)
111
+ n_rows,
112
+ n_cols,
113
+ rows_per_program: tl.constexpr,
114
+ BLOCK_SIZE: tl.constexpr,
115
+ ):
116
+ """
117
+ PolyNorm Backward Kernel Gradient:
118
+ ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
119
+
120
+ where:
121
+ - D_p = RMS(x^p) = 1/rstd_p
122
+ - S_p = sum(grad * x^p) over the row
123
+ - d = n_cols
124
+ - p ∈ {3, 2, 1}
125
+ """
126
+ row_block_id = tl.program_id(0).to(tl.int64)
127
+ row_start = row_block_id * rows_per_program
128
+ row_end = min((row_block_id + 1) * rows_per_program, n_rows)
129
+ col_offsets = tl.arange(0, BLOCK_SIZE)
130
+ mask = col_offsets < n_cols
131
+
132
+ # Initialize accumulators for weight and bias gradients (scalars)
133
+ dW0_acc = 0.0
134
+ dW1_acc = 0.0
135
+ dW2_acc = 0.0
136
+ dB_acc = 0.0
137
+
138
+ # Load weights
139
+ w0 = tl.load(W_ptr + 0).to(tl.float32)
140
+ w1 = tl.load(W_ptr + 1).to(tl.float32)
141
+ w2 = tl.load(W_ptr + 2).to(tl.float32)
142
+
143
+ dY_ptr += row_start * dY_row_stride
144
+ dX_ptr += row_start * dX_row_stride
145
+ X_ptr += row_start * X_row_stride
146
+ RSTD_ptr += row_start * RSTD_row_stride
147
+
148
+ for _ in range(row_start, row_end):
149
+ # Load input and gradient
150
+ dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
151
+ X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
152
+
153
+ # Load cached rstd values
154
+ rstd_3 = tl.load(RSTD_ptr + 0).to(tl.float32)
155
+ rstd_2 = tl.load(RSTD_ptr + 1).to(tl.float32)
156
+ rstd_1 = tl.load(RSTD_ptr + 2).to(tl.float32)
157
+
158
+ # Compute powers
159
+ X_pow3 = X_row * X_row * X_row
160
+ X_pow2 = X_row * X_row
161
+ X_pow1 = X_row
162
+
163
+ # Accumulate bias gradient: dB = sum(dY)
164
+ dB_acc += tl.sum(dY_row, axis=0)
165
+
166
+ # Compute gradient w.r.t. input using closed-form formula
167
+ # For p=3: ∂L/∂x from w0 * norm(x³)
168
+ S_3 = tl.sum(dY_row * X_pow3, axis=0) # scalar
169
+ grad_x_3 = w0 * (
170
+ 3.0 * X_pow2 * rstd_3 * dY_row
171
+ - (3.0 / n_cols) * X_row * X_row * X_row * X_row * X_row * (rstd_3 * rstd_3 * rstd_3) * S_3
172
+ )
173
+
174
+ # For p=2: ∂L/∂x from w1 * norm(x²)
175
+ S_2 = tl.sum(dY_row * X_pow2, axis=0) # scalar
176
+ grad_x_2 = w1 * (
177
+ 2.0 * X_row * rstd_2 * dY_row - (2.0 / n_cols) * X_row * X_row * X_row * (rstd_2 * rstd_2 * rstd_2) * S_2
178
+ )
179
+
180
+ # For p=1: ∂L/∂x from w2 * norm(x)
181
+ S_1 = tl.sum(dY_row * X_pow1, axis=0) # scalar
182
+ grad_x_1 = w2 * (1.0 * rstd_1 * dY_row - (1.0 / n_cols) * X_row * (rstd_1 * rstd_1 * rstd_1) * S_1)
183
+
184
+ # Accumulate weight gradients using closed-form: dW_p = rstd_p * S_p
185
+ dW0_acc += rstd_3 * S_3
186
+ dW1_acc += rstd_2 * S_2
187
+ dW2_acc += rstd_1 * S_1
188
+
189
+ # Total gradient
190
+ dX_row = grad_x_3 + grad_x_2 + grad_x_1
191
+
192
+ # Store gradient
193
+ tl.store(dX_ptr + col_offsets, dX_row, mask=mask)
194
+
195
+ # Update pointers
196
+ dY_ptr += dY_row_stride
197
+ dX_ptr += dX_row_stride
198
+ X_ptr += X_row_stride
199
+ RSTD_ptr += RSTD_row_stride
200
+
201
+ # Store accumulated gradients (scalars)
202
+ tl.store(dW_ptr + row_block_id * dW_row_stride + 0, dW0_acc)
203
+ tl.store(dW_ptr + row_block_id * dW_row_stride + 1, dW1_acc)
204
+ tl.store(dW_ptr + row_block_id * dW_row_stride + 2, dW2_acc)
205
+ tl.store(dB_ptr + row_block_id, dB_acc)
206
+
207
+
208
def poly_norm_forward(X, W, B, eps=1e-6):
    """
    PolyNorm forward pass.

    Computes y = w0*norm(x^3) + w1*norm(x^2) + w2*norm(x) + b row-wise,
    where norm(u) = u / sqrt(mean(u^2) + eps).

    Args:
        X: input tensor of shape (*, H) where H is hidden dimension
        W: weight tensor of shape (3,) for [w0, w1, w2]
        B: bias scalar tensor (numel == 1)
        eps: epsilon for numerical stability

    Returns:
        Y: output tensor of same shape as X
        X: input reshaped to 2D (for backward)
        RSTD: cached rstd values of shape (n_rows, 3) (for backward)
        BLOCK_SIZE: block size used
        num_warps: number of warps used

    Raises:
        ValueError: if W is not shape (3,) or B is not a scalar.
    """
    # Validate before allocating anything; explicit raises instead of
    # `assert`, which is silently stripped under `python -O`.
    if W.shape[0] != 3:
        raise ValueError("Weight tensor must have shape (3,)")
    if B.numel() != 1:
        raise ValueError("Bias must be a scalar")

    shape = X.shape
    dim = shape[-1]
    X = X.view(-1, dim)
    n_rows, n_cols = X.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
    # RSTD caches the three per-row reciprocal stds (fp32) for backward.
    RSTD = torch.empty((n_rows, 3), dtype=torch.float32, device=X.device)

    # XPU-specific optimization
    kernel_args = {}
    if X.device.type == "xpu":
        kernel_args["grf_mode"] = "large"

    # Launch kernel: one program per row.
    _poly_norm_forward_kernel[(n_rows,)](
        Y,
        Y.stride(0),
        X,
        X.stride(0),
        W,
        B,
        RSTD,
        RSTD.stride(0),
        n_cols,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        **kernel_args,
    )

    return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps
262
+
263
+
264
def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
    """
    PolyNorm backward pass.

    Args:
        dY: gradient of output
        X: input tensor (already reshaped to 2D by the forward pass)
        W: weight tensor of shape (3,)
        RSTD: cached rstd values from forward, shape (n_rows, 3)
        BLOCK_SIZE: block size from forward
        num_warps: number of warps from forward
        in_place: whether to in-place modify dY to store dX (saves memory)

    Returns:
        dX: gradient w.r.t. input (same shape as dY)
        dW: gradient w.r.t. weight
        dB: gradient w.r.t. bias
    """
    shape = dY.shape
    dim = shape[-1]
    dY = dY.view(-1, dim)
    n_rows, n_cols = dY.shape

    # One program per SM/EU so each accumulates partial dW/dB over a chunk
    # of rows; fall back to a single program on other device types.
    sm_count = 1
    if X.device.type == "cuda":
        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
    elif X.device.type == "xpu":
        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
    elif X.device.type == "npu":
        sm_count = get_npu_multi_processor_count()

    # Reuse dY's storage for dX when requested, otherwise allocate fresh.
    if in_place is True:
        dX = dY
    else:
        dX = torch.zeros_like(dY)

    # Per-program partial gradients, reduced on the host after the kernel.
    _dW = torch.empty((sm_count, 3), dtype=torch.float32, device=W.device)
    _dB = torch.empty((sm_count,), dtype=torch.float32, device=W.device)

    # Ceiling division via triton.cdiv (triton is already a module-level
    # import) instead of the original mid-function `import math`/math.ceil.
    rows_per_program = triton.cdiv(n_rows, sm_count)
    grid = (sm_count,)

    # XPU-specific optimization
    kernel_args = {}
    if X.device.type == "xpu":
        kernel_args["grf_mode"] = "large"

    # Launch backward kernel
    _poly_norm_backward_kernel[grid](
        dY,
        dY.stride(0),
        dX,
        dX.stride(0),
        X,
        X.stride(0),
        W,
        RSTD,
        RSTD.stride(0),
        _dW,
        _dW.stride(0),
        _dB,
        n_rows,
        n_cols,
        rows_per_program,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        **kernel_args,
    )

    # Reduce partial gradients across programs; restore shape and dtype.
    dX = dX.view(*shape)
    dW = _dW.sum(dim=0).to(W.dtype)
    dB = _dB.sum().to(W.dtype)

    return dX, dW, dB
343
+
344
+
345
class LigerPolyNormFunction(torch.autograd.Function):
    """Autograd wrapper around the PolyNorm Triton kernels.

    PolyNorm formula:
        y = w0*norm(x^3) + w1*norm(x^2) + w2*norm(x) + b
    where norm(u) = u / sqrt(mean(u^2) + eps)

    Backward uses the closed-form gradient:
        dL/dx_i = sum_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p^3)]
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, X, W, B, eps=1e-6, in_place=True):
        """Run the forward kernel and stash what backward needs.

        Args:
            X: input tensor of shape (B, T, H) or (BxT, H)
            W: weight tensor of shape (3,) for [w0, w1, w2]
            B: bias scalar
            eps: epsilon for numerical stability
            in_place: whether backward may reuse grad_output's storage

        Returns:
            Output tensor of the same shape as X.
        """
        Y, X_2d, rstd_cache, block_size, warp_count = poly_norm_forward(X, W, B, eps)
        ctx.save_for_backward(X_2d, W, rstd_cache)
        ctx.BLOCK_SIZE = block_size
        ctx.num_warps = warp_count
        ctx.in_place = in_place
        return Y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output):
        """Run the backward kernel.

        Args:
            grad_output: gradient of the output

        Returns:
            Gradients for (X, W, B) plus None for eps and in_place.
        """
        X_2d, W, rstd_cache = ctx.saved_tensors
        dX, dW, dB = poly_norm_backward(
            grad_output, X_2d, W, rstd_cache, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place
        )
        return dX, dW, dB, None, None