liger-kernel-nightly 0.5.10.dev20250611191801__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.

Potentially problematic release: this version of liger-kernel-nightly might be problematic.

Files changed (107)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +54 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +25 -5
  7. liger_kernel/chunked_loss/grpo_loss.py +46 -9
  8. liger_kernel/chunked_loss/jsd_loss.py +44 -13
  9. liger_kernel/ops/__init__.py +141 -0
  10. liger_kernel/ops/backends/README.md +151 -0
  11. liger_kernel/ops/backends/__init__.py +13 -0
  12. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  13. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  14. liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
  15. liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
  16. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  17. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  18. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  19. liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
  20. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  21. liger_kernel/ops/backends/registry.py +61 -0
  22. liger_kernel/ops/cross_entropy.py +130 -64
  23. liger_kernel/ops/dyt.py +5 -4
  24. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  25. liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
  26. liger_kernel/ops/geglu.py +6 -4
  27. liger_kernel/ops/group_norm.py +7 -7
  28. liger_kernel/ops/grpo_loss.py +3 -1
  29. liger_kernel/ops/kl_div.py +8 -11
  30. liger_kernel/ops/layer_norm.py +135 -80
  31. liger_kernel/ops/llama4_rope.py +225 -0
  32. liger_kernel/ops/poly_norm.py +390 -0
  33. liger_kernel/ops/rms_norm.py +148 -71
  34. liger_kernel/ops/rope.py +1 -1
  35. liger_kernel/ops/swiglu.py +1 -1
  36. liger_kernel/ops/tiled_mlp.py +136 -0
  37. liger_kernel/ops/utils.py +14 -0
  38. liger_kernel/transformers/__init__.py +65 -0
  39. liger_kernel/transformers/auto_model.py +21 -0
  40. liger_kernel/transformers/cross_entropy.py +9 -4
  41. liger_kernel/transformers/dyt.py +1 -1
  42. liger_kernel/transformers/experimental/__init__.py +5 -0
  43. liger_kernel/transformers/experimental/embedding.py +1 -1
  44. liger_kernel/transformers/functional.py +56 -24
  45. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  46. liger_kernel/transformers/fused_linear_cross_entropy.py +17 -5
  47. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  48. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  49. liger_kernel/transformers/geglu.py +1 -1
  50. liger_kernel/transformers/group_norm.py +1 -1
  51. liger_kernel/transformers/grpo_loss.py +57 -2
  52. liger_kernel/transformers/jsd.py +1 -1
  53. liger_kernel/transformers/kl_div.py +1 -1
  54. liger_kernel/transformers/layer_norm.py +1 -1
  55. liger_kernel/transformers/llama4_rope.py +93 -0
  56. liger_kernel/transformers/model/exaone4.py +136 -0
  57. liger_kernel/transformers/model/falcon_h1.py +122 -0
  58. liger_kernel/transformers/model/gemma.py +28 -8
  59. liger_kernel/transformers/model/gemma2.py +34 -11
  60. liger_kernel/transformers/model/gemma3.py +102 -112
  61. liger_kernel/transformers/model/glm4.py +18 -5
  62. liger_kernel/transformers/model/glm4v.py +163 -0
  63. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  64. liger_kernel/transformers/model/gpt_oss.py +211 -0
  65. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  66. liger_kernel/transformers/model/internvl.py +157 -0
  67. liger_kernel/transformers/model/llama.py +26 -7
  68. liger_kernel/transformers/model/llama4.py +121 -0
  69. liger_kernel/transformers/model/llava.py +18 -6
  70. liger_kernel/transformers/model/loss_utils.py +34 -3
  71. liger_kernel/transformers/model/mistral.py +17 -10
  72. liger_kernel/transformers/model/mixtral.py +24 -9
  73. liger_kernel/transformers/model/mllama.py +18 -7
  74. liger_kernel/transformers/model/olmo2.py +18 -5
  75. liger_kernel/transformers/model/olmo3.py +142 -0
  76. liger_kernel/transformers/model/output_classes.py +147 -0
  77. liger_kernel/transformers/model/paligemma.py +42 -5
  78. liger_kernel/transformers/model/phi3.py +24 -159
  79. liger_kernel/transformers/model/qwen2.py +26 -4
  80. liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
  81. liger_kernel/transformers/model/qwen2_vl.py +24 -7
  82. liger_kernel/transformers/model/qwen3.py +22 -6
  83. liger_kernel/transformers/model/qwen3_moe.py +27 -7
  84. liger_kernel/transformers/model/qwen3_next.py +146 -0
  85. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  86. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  87. liger_kernel/transformers/model/smollm3.py +199 -0
  88. liger_kernel/transformers/model/smolvlm.py +158 -0
  89. liger_kernel/transformers/monkey_patch.py +1423 -100
  90. liger_kernel/transformers/multi_token_attention.py +2 -2
  91. liger_kernel/transformers/poly_norm.py +42 -0
  92. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  93. liger_kernel/transformers/rms_norm.py +15 -5
  94. liger_kernel/transformers/rope.py +45 -1
  95. liger_kernel/transformers/softmax.py +1 -1
  96. liger_kernel/transformers/sparsemax.py +1 -1
  97. liger_kernel/transformers/swiglu.py +18 -1
  98. liger_kernel/transformers/tiled_mlp.py +125 -0
  99. liger_kernel/transformers/tvd.py +1 -1
  100. liger_kernel/utils.py +52 -0
  101. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +37 -25
  102. liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
  103. liger_kernel_nightly-0.5.10.dev20250611191801.dist-info/RECORD +0 -95
  104. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
  105. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
  106. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
  107. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
liger_kernel/ops/poly_norm.py
@@ -0,0 +1,390 @@
+ import operator
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.utils import get_npu_multi_processor_count
+ from liger_kernel.utils import is_npu_available
+
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
+     try:
+         from triton.language.extra.libdevice import rsqrt
+     except ModuleNotFoundError:
+         from triton.language.extra.cuda.libdevice import rsqrt
+ else:
+     from triton.language.math import rsqrt
+
+
+ @triton.jit
+ def _poly_norm_forward_kernel(
+     Y_ptr,
+     Y_row_stride,
+     X_ptr,
+     X_row_stride,
+     W_ptr, # weight: [3] for [w0, w1, w2]
+     B_ptr, # bias: scalar
+     RSTD_ptr, # cache rstd for backward: shape (n_rows, 3)
+     RSTD_row_stride,
+     n_cols,
+     eps,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     """
+     PolyNorm formula:
+         y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+         where norm(u) = u / sqrt(mean(u²) + ε)
+
+     Reference:
+     1. https://github.com/BryceZhuo/PolyCom/
+     2. https://arxiv.org/pdf/2411.03884
+
+     Cache rstd values for backward pass
+     """
+     row_idx = tl.program_id(0).to(tl.int64)
+     col_offsets = tl.arange(0, BLOCK_SIZE)
+     mask = col_offsets < n_cols
+
+     # Load pointers
+     Y_ptr += row_idx * Y_row_stride
+     X_ptr += row_idx * X_row_stride
+     RSTD_ptr += row_idx * RSTD_row_stride
+
+     # Load input row
+     X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
+
+     # Load weights and bias
+     w0 = tl.load(W_ptr + 0)
+     w1 = tl.load(W_ptr + 1)
+     w2 = tl.load(W_ptr + 2)
+     b = tl.load(B_ptr)
+
+     # Compute x³, x², x
+     X_pow3 = X_row * X_row * X_row
+     X_pow2 = X_row * X_row
+     X_pow1 = X_row
+
+     # Compute norm(x³): norm(u) = u * rsqrt(mean(u²) + eps)
+     mean_square_3 = tl.sum(X_pow3 * X_pow3, axis=0) / n_cols
+     rstd_3 = rsqrt(mean_square_3 + eps)
+     norm_x3 = X_pow3 * rstd_3
+
+     # Compute norm(x²)
+     mean_square_2 = tl.sum(X_pow2 * X_pow2, axis=0) / n_cols
+     rstd_2 = rsqrt(mean_square_2 + eps)
+     norm_x2 = X_pow2 * rstd_2
+
+     # Compute norm(x)
+     mean_square_1 = tl.sum(X_pow1 * X_pow1, axis=0) / n_cols
+     rstd_1 = rsqrt(mean_square_1 + eps)
+     norm_x1 = X_pow1 * rstd_1
+
+     # Cache rstd values for backward
+     tl.store(RSTD_ptr + 0, rstd_3)
+     tl.store(RSTD_ptr + 1, rstd_2)
+     tl.store(RSTD_ptr + 2, rstd_1)
+
+     # Compute output: y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+     Y_row = w0 * norm_x3 + w1 * norm_x2 + w2 * norm_x1 + b
+
+     # Store output
+     tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
+
+
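For reference, the formula in the kernel docstring corresponds to the following eager PyTorch sketch (illustrative only; poly_norm_reference is a hypothetical helper, not part of the package):

import torch


def poly_norm_reference(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # norm(u) = u / sqrt(mean(u²) + eps), with the mean taken over the last dimension
    def norm(u: torch.Tensor) -> torch.Tensor:
        return u * torch.rsqrt(u.pow(2).mean(dim=-1, keepdim=True) + eps)

    # y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
    return w[0] * norm(x**3) + w[1] * norm(x**2) + w[2] * norm(x) + b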
+ @triton.jit
+ def _poly_norm_backward_kernel(
+     dY_ptr,
+     dY_row_stride,
+     dX_ptr,
+     dX_row_stride,
+     X_ptr,
+     X_row_stride,
+     W_ptr,
+     RSTD_ptr,
+     RSTD_row_stride,
+     dW_ptr, # shape: (n_programs, 3)
+     dW_row_stride,
+     dB_ptr, # shape: (n_programs,)
+     n_rows,
+     n_cols,
+     rows_per_program: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     """
+     PolyNorm Backward Kernel Gradient:
+         ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
+
+     where:
+         - D_p = RMS(x^p) = 1/rstd_p
+         - S_p = sum(grad * x^p) over the row
+         - d = n_cols
+         - p ∈ {3, 2, 1}
+     """
+     row_block_id = tl.program_id(0).to(tl.int64)
+     row_start = row_block_id * rows_per_program
+     row_end = min((row_block_id + 1) * rows_per_program, n_rows)
+     col_offsets = tl.arange(0, BLOCK_SIZE)
+     mask = col_offsets < n_cols
+
+     # Initialize accumulators for weight and bias gradients (scalars)
+     dW0_acc = 0.0
+     dW1_acc = 0.0
+     dW2_acc = 0.0
+     dB_acc = 0.0
+
+     # Load weights
+     w0 = tl.load(W_ptr + 0).to(tl.float32)
+     w1 = tl.load(W_ptr + 1).to(tl.float32)
+     w2 = tl.load(W_ptr + 2).to(tl.float32)
+
+     dY_ptr += row_start * dY_row_stride
+     dX_ptr += row_start * dX_row_stride
+     X_ptr += row_start * X_row_stride
+     RSTD_ptr += row_start * RSTD_row_stride
+
+     for _ in range(row_start, row_end):
+         # Load input and gradient
+         dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
+         X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
+
+         # Load cached rstd values
+         rstd_3 = tl.load(RSTD_ptr + 0).to(tl.float32)
+         rstd_2 = tl.load(RSTD_ptr + 1).to(tl.float32)
+         rstd_1 = tl.load(RSTD_ptr + 2).to(tl.float32)
+
+         # Compute powers
+         X_pow3 = X_row * X_row * X_row
+         X_pow2 = X_row * X_row
+         X_pow1 = X_row
+
+         # Accumulate bias gradient: dB = sum(dY)
+         dB_acc += tl.sum(dY_row, axis=0)
+
+         # Compute gradient w.r.t. input using closed-form formula
+         # For p=3: ∂L/∂x from w0 * norm(x³)
+         S_3 = tl.sum(dY_row * X_pow3, axis=0) # scalar
+         grad_x_3 = w0 * (
+             3.0 * X_pow2 * rstd_3 * dY_row
+             - (3.0 / n_cols) * X_row * X_row * X_row * X_row * X_row * (rstd_3 * rstd_3 * rstd_3) * S_3
+         )
+
+         # For p=2: ∂L/∂x from w1 * norm(x²)
+         S_2 = tl.sum(dY_row * X_pow2, axis=0) # scalar
+         grad_x_2 = w1 * (
+             2.0 * X_row * rstd_2 * dY_row - (2.0 / n_cols) * X_row * X_row * X_row * (rstd_2 * rstd_2 * rstd_2) * S_2
+         )
+
+         # For p=1: ∂L/∂x from w2 * norm(x)
+         S_1 = tl.sum(dY_row * X_pow1, axis=0) # scalar
+         grad_x_1 = w2 * (1.0 * rstd_1 * dY_row - (1.0 / n_cols) * X_row * (rstd_1 * rstd_1 * rstd_1) * S_1)
+
+         # Accumulate weight gradients using closed-form: dW_p = rstd_p * S_p
+         dW0_acc += rstd_3 * S_3
+         dW1_acc += rstd_2 * S_2
+         dW2_acc += rstd_1 * S_1
+
+         # Total gradient
+         dX_row = grad_x_3 + grad_x_2 + grad_x_1
+
+         # Store gradient
+         tl.store(dX_ptr + col_offsets, dX_row, mask=mask)
+
+         # Update pointers
+         dY_ptr += dY_row_stride
+         dX_ptr += dX_row_stride
+         X_ptr += X_row_stride
+         RSTD_ptr += RSTD_row_stride
+
+     # Store accumulated gradients (scalars)
+     tl.store(dW_ptr + row_block_id * dW_row_stride + 0, dW0_acc)
+     tl.store(dW_ptr + row_block_id * dW_row_stride + 1, dW1_acc)
+     tl.store(dW_ptr + row_block_id * dW_row_stride + 2, dW2_acc)
+     tl.store(dB_ptr + row_block_id, dB_acc)
+
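The closed-form gradient in the kernel docstring can be written per power p as a short eager PyTorch sketch (an illustrative helper under the same eps convention, not a package API):

import torch


def poly_norm_grad_term(dy: torch.Tensor, x: torch.Tensor, p: int, w_p: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # rstd_p = 1 / sqrt(mean(x^(2p)) + eps), i.e. 1 / D_p in the docstring
    rstd = torch.rsqrt(x.pow(2 * p).mean(dim=-1, keepdim=True) + eps)
    # S_p = sum(grad * x^p) over the row
    s_p = (dy * x.pow(p)).sum(dim=-1, keepdim=True)
    # w_p * [p * x^(p-1) * rstd * dy - (p/d) * x^(2p-1) * rstd³ * S_p]
    d = x.shape[-1]
    return w_p * (p * x.pow(p - 1) * rstd * dy - (p / d) * x.pow(2 * p - 1) * rstd.pow(3) * s_p)

The full input gradient is the sum of this term over p ∈ {3, 2, 1}, which is what the kernel accumulates as grad_x_3 + grad_x_2 + grad_x_1.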
+
+ def poly_norm_forward(X, W, B, eps=1e-6):
+     """
+     PolyNorm Forward Pass
+
+     Args:
+         X: input tensor of shape (*, H) where H is hidden dimension
+         W: weight tensor of shape (3,) for [w0, w1, w2]
+         B: bias scalar tensor
+         eps: epsilon for numerical stability
+
+     Returns:
+         Y: output tensor of same shape as X
+         X: reshaped input (for backward)
+         RSTD: cached rstd values (for backward)
+         BLOCK_SIZE: block size used
+         num_warps: number of warps used
+     """
+     shape = X.shape
+     dim = shape[-1]
+     X = X.view(-1, dim)
+     n_rows, n_cols = X.shape
+     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+     # RSTD is to cache rstd for each row
+     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+     RSTD = torch.empty((n_rows, 3), dtype=torch.float32, device=X.device)
+
+     # Check constraints
+     assert W.shape[0] == 3, "Weight tensor must have shape (3,)"
+     assert B.numel() == 1, "Bias must be a scalar"
+
+     # XPU-specific optimization
+     kernel_args = {}
+     if X.device.type == "xpu":
+         kernel_args["grf_mode"] = "large"
+
+     # Launch kernel
+     _poly_norm_forward_kernel[(n_rows,)](
+         Y,
+         Y.stride(0),
+         X,
+         X.stride(0),
+         W,
+         B,
+         RSTD,
+         RSTD.stride(0),
+         n_cols,
+         eps,
+         BLOCK_SIZE=BLOCK_SIZE,
+         num_warps=num_warps,
+         **kernel_args,
+     )
+
+     return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps
+
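A rough usage sketch for poly_norm_forward (assumes a Triton-capable CUDA device; the shapes are arbitrary examples):

import torch

from liger_kernel.ops.poly_norm import poly_norm_forward

x = torch.randn(2, 4, 128, device="cuda")
w = torch.randn(3, device="cuda")
b = torch.randn((), device="cuda")

y, x_2d, rstd, block_size, num_warps = poly_norm_forward(x, w, b, eps=1e-6)
# y keeps the input shape; x_2d is the flattened (B*T, H) view reused by backward;
# rstd holds one float32 (rstd_3, rstd_2, rstd_1) triple per row
assert y.shape == x.shape and x_2d.shape == (8, 128) and rstd.shape == (8, 3)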
+
+ def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
+     """
+     PolyNorm Backward Pass
+
+     Args:
+         dY: gradient of output
+         X: input tensor (already reshaped to 2D)
+         W: weight tensor
+         RSTD: cached rstd values from forward
+         BLOCK_SIZE: block size from forward
+         num_warps: number of warps from forward
+         in_place: whether to in-place modify dY to store dX (saves memory)
+
+     Returns:
+         dX: gradient w.r.t. input
+         dW: gradient w.r.t. weight
+         dB: gradient w.r.t. bias
+     """
+     shape = dY.shape
+     dim = shape[-1]
+     dY = dY.view(-1, dim)
+     n_rows, n_cols = dY.shape
+
+     # Get number of SMs for parallelization
+     import math
+
+     sm_count = 1
+     if X.device.type == "cuda":
+         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+     elif X.device.type == "xpu":
+         sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+     elif X.device.type == "npu":
+         sm_count = get_npu_multi_processor_count()
+
+     # Allocate or reuse gradients
+     if in_place is True:
+         dX = dY
+     else:
+         dX = torch.zeros_like(dY)
+
+     _dW = torch.empty((sm_count, 3), dtype=torch.float32, device=W.device)
+     _dB = torch.empty((sm_count,), dtype=torch.float32, device=W.device)
+
+     rows_per_program = math.ceil(n_rows / sm_count)
+     grid = (sm_count,)
+
+     # XPU-specific optimization
+     kernel_args = {}
+     if X.device.type == "xpu":
+         kernel_args["grf_mode"] = "large"
+
+     # Launch backward kernel
+     _poly_norm_backward_kernel[grid](
+         dY,
+         dY.stride(0),
+         dX,
+         dX.stride(0),
+         X,
+         X.stride(0),
+         W,
+         RSTD,
+         RSTD.stride(0),
+         _dW,
+         _dW.stride(0),
+         _dB,
+         n_rows,
+         n_cols,
+         rows_per_program,
+         BLOCK_SIZE=BLOCK_SIZE,
+         num_warps=num_warps,
+         **kernel_args,
+     )
+
+     # Reduce gradients across SMs
+     dX = dX.view(*shape)
+     dW = _dW.sum(dim=0).to(W.dtype)
+     dB = _dB.sum().to(W.dtype)
+
+     return dX, dW, dB
+
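Combining the two sketches above, the closed-form backward can be sanity-checked against autograd of the eager reference (poly_norm_reference and poly_norm_grad_term are the hypothetical helpers from the earlier sketches, not package functions):

import torch

x = torch.randn(4, 64, dtype=torch.float64, requires_grad=True)
w = torch.randn(3, dtype=torch.float64)
b = torch.randn((), dtype=torch.float64)

y = poly_norm_reference(x, w, b)
dy = torch.randn_like(y)
(dx_autograd,) = torch.autograd.grad(y, x, grad_outputs=dy)

# Σ_p w_p·[...] over p ∈ {3, 2, 1}, mirroring the backward kernel
dx_closed_form = sum(poly_norm_grad_term(dy, x.detach(), p, w[i]) for i, p in enumerate((3, 2, 1)))
torch.testing.assert_close(dx_autograd, dx_closed_form)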
+
+ class LigerPolyNormFunction(torch.autograd.Function):
+     """
+     PolyNorm Function with forward and backward pass
+
+     PolyNorm formula:
+         y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+         where norm(u) = u / sqrt(mean(u²) + ε)
+
+     Backward uses closed-form gradient:
+         ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
+     """
+
+     @staticmethod
+     @ensure_contiguous
+     def forward(ctx, X, W, B, eps=1e-6, in_place=True):
+         """
+         Args:
+             X: input tensor of shape (B, T, H) or (BxT, H)
+             W: weight tensor of shape (3,) for [w0, w1, w2]
+             B: bias scalar
+             eps: epsilon for numerical stability
+             in_place: whether to in-place modify grad_output in backward (saves memory)
+
+         Returns:
+             Y: output tensor of same shape as X
+         """
+         Y, X, RSTD, BLOCK_SIZE, num_warps = poly_norm_forward(X, W, B, eps)
+         ctx.BLOCK_SIZE = BLOCK_SIZE
+         ctx.num_warps = num_warps
+         ctx.in_place = in_place
+         ctx.save_for_backward(X, W, RSTD)
+         return Y
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, grad_output):
+         """
+         Args:
+             grad_output: gradient of output
+
+         Returns:
+             dX, dW, dB: gradients w.r.t. X, W, B
+         """
+         X, W, RSTD = ctx.saved_tensors
+         dX, dW, dB = poly_norm_backward(grad_output, X, W, RSTD, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place)
+         return dX, dW, dB, None, None
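Finally, a minimal end-to-end sketch of the autograd entry point added by this file (assumes a Triton-capable CUDA device; the import path follows the new liger_kernel/ops/poly_norm.py module):

import torch

from liger_kernel.ops.poly_norm import LigerPolyNormFunction

x = torch.randn(2, 16, 256, device="cuda", requires_grad=True)
w = torch.nn.Parameter(torch.ones(3, device="cuda") / 3)
b = torch.nn.Parameter(torch.zeros((), device="cuda"))

# Signature: apply(X, W, B, eps=1e-6, in_place=True); in_place=False keeps
# grad_output untouched at the cost of an extra buffer.
y = LigerPolyNormFunction.apply(x, w, b, 1e-6, False)
y.sum().backward()

print(x.grad.shape, w.grad.shape, b.grad.shape)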