liger-kernel-nightly 0.6.2.dev20251011154427__py3-none-any.whl → 0.6.4.dev20251202054858__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (67)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +13 -4
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +8 -5
  5. liger_kernel/chunked_loss/jsd_loss.py +18 -5
  6. liger_kernel/ops/cross_entropy.py +65 -11
  7. liger_kernel/ops/dyt.py +5 -2
  8. liger_kernel/ops/fused_add_rms_norm.py +5 -1
  9. liger_kernel/ops/fused_linear_cross_entropy.py +43 -13
  10. liger_kernel/ops/geglu.py +2 -1
  11. liger_kernel/ops/group_norm.py +2 -1
  12. liger_kernel/ops/grpo_loss.py +3 -1
  13. liger_kernel/ops/layer_norm.py +86 -66
  14. liger_kernel/ops/poly_norm.py +390 -0
  15. liger_kernel/ops/rms_norm.py +7 -2
  16. liger_kernel/ops/tiled_mlp.py +136 -0
  17. liger_kernel/ops/utils.py +2 -0
  18. liger_kernel/transformers/__init__.py +27 -0
  19. liger_kernel/transformers/cross_entropy.py +8 -3
  20. liger_kernel/transformers/functional.py +29 -6
  21. liger_kernel/transformers/fused_linear_cross_entropy.py +8 -3
  22. liger_kernel/transformers/grpo_loss.py +56 -1
  23. liger_kernel/transformers/model/falcon_h1.py +19 -5
  24. liger_kernel/transformers/model/gemma.py +17 -6
  25. liger_kernel/transformers/model/gemma2.py +14 -5
  26. liger_kernel/transformers/model/gemma3.py +25 -12
  27. liger_kernel/transformers/model/glm4.py +16 -4
  28. liger_kernel/transformers/model/glm4v.py +16 -4
  29. liger_kernel/transformers/model/glm4v_moe.py +23 -4
  30. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  31. liger_kernel/transformers/model/internvl.py +12 -5
  32. liger_kernel/transformers/model/llama.py +14 -5
  33. liger_kernel/transformers/model/llama4.py +16 -4
  34. liger_kernel/transformers/model/llava.py +12 -4
  35. liger_kernel/transformers/model/loss_utils.py +31 -3
  36. liger_kernel/transformers/model/mistral.py +15 -6
  37. liger_kernel/transformers/model/mixtral.py +16 -7
  38. liger_kernel/transformers/model/mllama.py +12 -4
  39. liger_kernel/transformers/model/olmo2.py +16 -4
  40. liger_kernel/transformers/model/olmo3.py +142 -0
  41. liger_kernel/transformers/model/output_classes.py +147 -0
  42. liger_kernel/transformers/model/paligemma.py +22 -5
  43. liger_kernel/transformers/model/phi3.py +14 -7
  44. liger_kernel/transformers/model/qwen2.py +16 -3
  45. liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
  46. liger_kernel/transformers/model/qwen2_vl.py +16 -4
  47. liger_kernel/transformers/model/qwen3.py +20 -5
  48. liger_kernel/transformers/model/qwen3_moe.py +19 -5
  49. liger_kernel/transformers/model/qwen3_next.py +146 -0
  50. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  51. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  52. liger_kernel/transformers/model/smollm3.py +15 -6
  53. liger_kernel/transformers/model/smolvlm.py +158 -0
  54. liger_kernel/transformers/monkey_patch.py +594 -19
  55. liger_kernel/transformers/poly_norm.py +42 -0
  56. liger_kernel/transformers/rms_norm.py +7 -0
  57. liger_kernel/transformers/rope.py +43 -0
  58. liger_kernel/transformers/swiglu.py +17 -0
  59. liger_kernel/transformers/tiled_mlp.py +133 -0
  60. liger_kernel/utils.py +25 -0
  61. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/METADATA +4 -1
  62. liger_kernel_nightly-0.6.4.dev20251202054858.dist-info/RECORD +118 -0
  63. liger_kernel_nightly-0.6.2.dev20251011154427.dist-info/RECORD +0 -107
  64. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/LICENSE +0 -0
  65. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/NOTICE +0 -0
  66. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/WHEEL +0 -0
  67. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/top_level.txt +0 -0
liger_kernel/ops/poly_norm.py ADDED
@@ -0,0 +1,390 @@
+ import operator
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.utils import get_npu_multi_processor_count
+ from liger_kernel.utils import is_npu_available
+
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
+     try:
+         from triton.language.extra.libdevice import rsqrt
+     except ModuleNotFoundError:
+         from triton.language.extra.cuda.libdevice import rsqrt
+ else:
+     from triton.language.math import rsqrt
+
+
+ @triton.jit
+ def _poly_norm_forward_kernel(
+     Y_ptr,
+     Y_row_stride,
+     X_ptr,
+     X_row_stride,
+     W_ptr,  # weight: [3] for [w0, w1, w2]
+     B_ptr,  # bias: scalar
+     RSTD_ptr,  # cache rstd for backward: shape (n_rows, 3)
+     RSTD_row_stride,
+     n_cols,
+     eps,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     """
+     PolyNorm formula:
+         y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+     where norm(u) = u / sqrt(mean(u²) + ε)
+
+     Reference:
+     1. https://github.com/BryceZhuo/PolyCom/
+     2. https://arxiv.org/pdf/2411.03884
+
+     Cache rstd values for backward pass
+     """
+     row_idx = tl.program_id(0).to(tl.int64)
+     col_offsets = tl.arange(0, BLOCK_SIZE)
+     mask = col_offsets < n_cols
+
+     # Load pointers
+     Y_ptr += row_idx * Y_row_stride
+     X_ptr += row_idx * X_row_stride
+     RSTD_ptr += row_idx * RSTD_row_stride
+
+     # Load input row
+     X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
+
+     # Load weights and bias
+     w0 = tl.load(W_ptr + 0)
+     w1 = tl.load(W_ptr + 1)
+     w2 = tl.load(W_ptr + 2)
+     b = tl.load(B_ptr)
+
+     # Compute x³, x², x
+     X_pow3 = X_row * X_row * X_row
+     X_pow2 = X_row * X_row
+     X_pow1 = X_row
+
+     # Compute norm(x³): norm(u) = u * rsqrt(mean(u²) + eps)
+     mean_square_3 = tl.sum(X_pow3 * X_pow3, axis=0) / n_cols
+     rstd_3 = rsqrt(mean_square_3 + eps)
+     norm_x3 = X_pow3 * rstd_3
+
+     # Compute norm(x²)
+     mean_square_2 = tl.sum(X_pow2 * X_pow2, axis=0) / n_cols
+     rstd_2 = rsqrt(mean_square_2 + eps)
+     norm_x2 = X_pow2 * rstd_2
+
+     # Compute norm(x)
+     mean_square_1 = tl.sum(X_pow1 * X_pow1, axis=0) / n_cols
+     rstd_1 = rsqrt(mean_square_1 + eps)
+     norm_x1 = X_pow1 * rstd_1
+
+     # Cache rstd values for backward
+     tl.store(RSTD_ptr + 0, rstd_3)
+     tl.store(RSTD_ptr + 1, rstd_2)
+     tl.store(RSTD_ptr + 2, rstd_1)
+
+     # Compute output: y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+     Y_row = w0 * norm_x3 + w1 * norm_x2 + w2 * norm_x1 + b
+
+     # Store output
+     tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
+
+
+ @triton.jit
+ def _poly_norm_backward_kernel(
+     dY_ptr,
+     dY_row_stride,
+     dX_ptr,
+     dX_row_stride,
+     X_ptr,
+     X_row_stride,
+     W_ptr,
+     RSTD_ptr,
+     RSTD_row_stride,
+     dW_ptr,  # shape: (n_programs, 3)
+     dW_row_stride,
+     dB_ptr,  # shape: (n_programs,)
+     n_rows,
+     n_cols,
+     rows_per_program: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     """
+     PolyNorm Backward Kernel Gradient:
+         ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
+
+     where:
+     - D_p = RMS(x^p) = 1/rstd_p
+     - S_p = sum(grad * x^p) over the row
+     - d = n_cols
+     - p ∈ {3, 2, 1}
+     """
+     row_block_id = tl.program_id(0).to(tl.int64)
+     row_start = row_block_id * rows_per_program
+     row_end = min((row_block_id + 1) * rows_per_program, n_rows)
+     col_offsets = tl.arange(0, BLOCK_SIZE)
+     mask = col_offsets < n_cols
+
+     # Initialize accumulators for weight and bias gradients (scalars)
+     dW0_acc = 0.0
+     dW1_acc = 0.0
+     dW2_acc = 0.0
+     dB_acc = 0.0
+
+     # Load weights
+     w0 = tl.load(W_ptr + 0).to(tl.float32)
+     w1 = tl.load(W_ptr + 1).to(tl.float32)
+     w2 = tl.load(W_ptr + 2).to(tl.float32)
+
+     dY_ptr += row_start * dY_row_stride
+     dX_ptr += row_start * dX_row_stride
+     X_ptr += row_start * X_row_stride
+     RSTD_ptr += row_start * RSTD_row_stride
+
+     for _ in range(row_start, row_end):
+         # Load input and gradient
+         dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
+         X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
+
+         # Load cached rstd values
+         rstd_3 = tl.load(RSTD_ptr + 0).to(tl.float32)
+         rstd_2 = tl.load(RSTD_ptr + 1).to(tl.float32)
+         rstd_1 = tl.load(RSTD_ptr + 2).to(tl.float32)
+
+         # Compute powers
+         X_pow3 = X_row * X_row * X_row
+         X_pow2 = X_row * X_row
+         X_pow1 = X_row
+
+         # Accumulate bias gradient: dB = sum(dY)
+         dB_acc += tl.sum(dY_row, axis=0)
+
+         # Compute gradient w.r.t. input using closed-form formula
+         # For p=3: ∂L/∂x from w0 * norm(x³)
+         S_3 = tl.sum(dY_row * X_pow3, axis=0)  # scalar
+         grad_x_3 = w0 * (
+             3.0 * X_pow2 * rstd_3 * dY_row
+             - (3.0 / n_cols) * X_row * X_row * X_row * X_row * X_row * (rstd_3 * rstd_3 * rstd_3) * S_3
+         )
+
+         # For p=2: ∂L/∂x from w1 * norm(x²)
+         S_2 = tl.sum(dY_row * X_pow2, axis=0)  # scalar
+         grad_x_2 = w1 * (
+             2.0 * X_row * rstd_2 * dY_row - (2.0 / n_cols) * X_row * X_row * X_row * (rstd_2 * rstd_2 * rstd_2) * S_2
+         )
+
+         # For p=1: ∂L/∂x from w2 * norm(x)
+         S_1 = tl.sum(dY_row * X_pow1, axis=0)  # scalar
+         grad_x_1 = w2 * (1.0 * rstd_1 * dY_row - (1.0 / n_cols) * X_row * (rstd_1 * rstd_1 * rstd_1) * S_1)
+
+         # Accumulate weight gradients using closed-form: dW_p = rstd_p * S_p
+         dW0_acc += rstd_3 * S_3
+         dW1_acc += rstd_2 * S_2
+         dW2_acc += rstd_1 * S_1
+
+         # Total gradient
+         dX_row = grad_x_3 + grad_x_2 + grad_x_1
+
+         # Store gradient
+         tl.store(dX_ptr + col_offsets, dX_row, mask=mask)
+
+         # Update pointers
+         dY_ptr += dY_row_stride
+         dX_ptr += dX_row_stride
+         X_ptr += X_row_stride
+         RSTD_ptr += RSTD_row_stride
+
+     # Store accumulated gradients (scalars)
+     tl.store(dW_ptr + row_block_id * dW_row_stride + 0, dW0_acc)
+     tl.store(dW_ptr + row_block_id * dW_row_stride + 1, dW1_acc)
+     tl.store(dW_ptr + row_block_id * dW_row_stride + 2, dW2_acc)
+     tl.store(dB_ptr + row_block_id, dB_acc)
+
+
+ def poly_norm_forward(X, W, B, eps=1e-6):
+     """
+     PolyNorm Forward Pass
+
+     Args:
+         X: input tensor of shape (*, H) where H is hidden dimension
+         W: weight tensor of shape (3,) for [w0, w1, w2]
+         B: bias scalar tensor
+         eps: epsilon for numerical stability
+
+     Returns:
+         Y: output tensor of same shape as X
+         X: reshaped input (for backward)
+         RSTD: cached rstd values (for backward)
+         BLOCK_SIZE: block size used
+         num_warps: number of warps used
+     """
+     shape = X.shape
+     dim = shape[-1]
+     X = X.view(-1, dim)
+     n_rows, n_cols = X.shape
+     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+     # RSTD is to cache rstd for each row
+     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+     RSTD = torch.empty((n_rows, 3), dtype=torch.float32, device=X.device)
+
+     # Check constraints
+     assert W.shape[0] == 3, "Weight tensor must have shape (3,)"
+     assert B.numel() == 1, "Bias must be a scalar"
+
+     # XPU-specific optimization
+     kernel_args = {}
+     if X.device.type == "xpu":
+         kernel_args["grf_mode"] = "large"
+
+     # Launch kernel
+     _poly_norm_forward_kernel[(n_rows,)](
+         Y,
+         Y.stride(0),
+         X,
+         X.stride(0),
+         W,
+         B,
+         RSTD,
+         RSTD.stride(0),
+         n_cols,
+         eps,
+         BLOCK_SIZE=BLOCK_SIZE,
+         num_warps=num_warps,
+         **kernel_args,
+     )
+
+     return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps
+
+
+ def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
+     """
+     PolyNorm Backward Pass
+
+     Args:
+         dY: gradient of output
+         X: input tensor (already reshaped to 2D)
+         W: weight tensor
+         RSTD: cached rstd values from forward
+         BLOCK_SIZE: block size from forward
+         num_warps: number of warps from forward
+         in_place: whether to in-place modify dY to store dX (saves memory)
+
+     Returns:
+         dX: gradient w.r.t. input
+         dW: gradient w.r.t. weight
+         dB: gradient w.r.t. bias
+     """
+     shape = dY.shape
+     dim = shape[-1]
+     dY = dY.view(-1, dim)
+     n_rows, n_cols = dY.shape
+
+     # Get number of SMs for parallelization
+     import math
+
+     sm_count = 1
+     if X.device.type == "cuda":
+         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+     elif X.device.type == "xpu":
+         sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+     elif X.device.type == "npu":
+         sm_count = get_npu_multi_processor_count()
+
+     # Allocate or reuse gradients
+     if in_place is True:
+         dX = dY
+     else:
+         dX = torch.zeros_like(dY)
+
+     _dW = torch.empty((sm_count, 3), dtype=torch.float32, device=W.device)
+     _dB = torch.empty((sm_count,), dtype=torch.float32, device=W.device)
+
+     rows_per_program = math.ceil(n_rows / sm_count)
+     grid = (sm_count,)
+
+     # XPU-specific optimization
+     kernel_args = {}
+     if X.device.type == "xpu":
+         kernel_args["grf_mode"] = "large"
+
+     # Launch backward kernel
+     _poly_norm_backward_kernel[grid](
+         dY,
+         dY.stride(0),
+         dX,
+         dX.stride(0),
+         X,
+         X.stride(0),
+         W,
+         RSTD,
+         RSTD.stride(0),
+         _dW,
+         _dW.stride(0),
+         _dB,
+         n_rows,
+         n_cols,
+         rows_per_program,
+         BLOCK_SIZE=BLOCK_SIZE,
+         num_warps=num_warps,
+         **kernel_args,
+     )
+
+     # Reduce gradients across SMs
+     dX = dX.view(*shape)
+     dW = _dW.sum(dim=0).to(W.dtype)
+     dB = _dB.sum().to(W.dtype)
+
+     return dX, dW, dB
+
+
+ class LigerPolyNormFunction(torch.autograd.Function):
+     """
+     PolyNorm Function with forward and backward pass
+
+     PolyNorm formula:
+         y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+     where norm(u) = u / sqrt(mean(u²) + ε)
+
+     Backward uses closed-form gradient:
+         ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
+     """
+
+     @staticmethod
+     @ensure_contiguous
+     def forward(ctx, X, W, B, eps=1e-6, in_place=True):
+         """
+         Args:
+             X: input tensor of shape (B, T, H) or (BxT, H)
+             W: weight tensor of shape (3,) for [w0, w1, w2]
+             B: bias scalar
+             eps: epsilon for numerical stability
+             in_place: whether to in-place modify grad_output in backward (saves memory)
+
+         Returns:
+             Y: output tensor of same shape as X
+         """
+         Y, X, RSTD, BLOCK_SIZE, num_warps = poly_norm_forward(X, W, B, eps)
+         ctx.BLOCK_SIZE = BLOCK_SIZE
+         ctx.num_warps = num_warps
+         ctx.in_place = in_place
+         ctx.save_for_backward(X, W, RSTD)
+         return Y
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, grad_output):
+         """
+         Args:
+             grad_output: gradient of output
+
+         Returns:
+             dX, dW, dB: gradients w.r.t. X, W, B
+         """
+         X, W, RSTD = ctx.saved_tensors
+         dX, dW, dB = poly_norm_backward(grad_output, X, W, RSTD, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place)
+         return dX, dW, dB, None, None
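
The fused kernel above implements y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b with norm(u) = u / sqrt(mean(u²) + ε). Below is a minimal eager-mode PyTorch sketch of the same formula, useful for sanity-checking the Triton path; the helper name poly_norm_reference is hypothetical and not part of the package.

import torch

def poly_norm_reference(x, w, b, eps=1e-6):
    # norm(u) = u * rsqrt(mean(u^2) + eps), reduced over the last (hidden) dimension
    def norm(u):
        return u * torch.rsqrt(u.pow(2).mean(dim=-1, keepdim=True) + eps)

    # y = w0 * norm(x^3) + w1 * norm(x^2) + w2 * norm(x) + b
    return w[0] * norm(x**3) + w[1] * norm(x**2) + w[2] * norm(x) + b

# On a CUDA/XPU/NPU device this can be compared against the fused path, e.g.:
# y_fused = LigerPolyNormFunction.apply(x, w, b)
# torch.testing.assert_close(poly_norm_reference(x, w, b), y_fused, rtol=1e-4, atol=1e-4)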
liger_kernel/ops/rms_norm.py CHANGED
@@ -21,8 +21,10 @@ from liger_kernel.ops.utils import calculate_settings
  from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import ensure_contiguous
  from liger_kernel.ops.utils import torch_to_triton_dtype
+ from liger_kernel.utils import get_npu_multi_processor_count
+ from liger_kernel.utils import is_npu_available

- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
      try:
          # typical import path with dispatch available
          from triton.language.extra.libdevice import rsqrt
@@ -349,7 +351,8 @@ def _block_rms_norm_backward_kernel(

      # calculate the gradient of W
      if casting_mode == _CASTING_MODE_LLAMA:
-         dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]).to(X_dtype), 0)
+         # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
+         dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
      else:
          # here X_row is already in fp32 (see previous if block)
          dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
@@ -449,6 +452,8 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
          sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
      elif X.device.type == "xpu":
          sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+     elif X.device.type == "npu":
+         sm_count = get_npu_multi_processor_count()

      # fp32 for numerical stability especially.
      _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
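
The dW change above casts each per-row product to float32 before `tl.sum`. As a rough eager-mode illustration of why low-precision accumulation drifts (this is a plain PyTorch analogy, not the kernel's actual code path):

import torch

# Accumulating many small bf16 terms in a bf16 running sum stalls once the sum's
# ULP exceeds the addend; accumulating in fp32 (as the kernel now does for dW) does not.
vals = torch.full((4096,), 0.01, dtype=torch.bfloat16)

acc_bf16 = torch.tensor(0.0, dtype=torch.bfloat16)
acc_fp32 = torch.tensor(0.0, dtype=torch.float32)
for v in vals:
    acc_bf16 += v          # rounds the running sum to bf16 at every step
    acc_fp32 += v.float()  # keeps the running sum in fp32

print(acc_bf16.item(), acc_fp32.item())  # the bf16 sum lands far below the ~41 expected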
liger_kernel/ops/tiled_mlp.py ADDED
@@ -0,0 +1,136 @@
+ import math
+
+ from typing import Callable
+ from typing import List
+ from typing import Optional
+
+ import torch
+
+ from liger_kernel.ops.utils import ensure_contiguous
+
+
+ class LigerTiledMLPFunction(torch.autograd.Function):
+     """
+     Based on DeepSpeed's TiledMLP:
+     https://github.com/deepspeedai/DeepSpeed/blob/v0.18.2/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L838
+
+     Perform a tiled MLP computation to massively reduce memory usage needed to compute MLP
+     when using very long sequence lengths.
+
+     This module re-computes `forward` in the `backward`. So the `forward` occurs twice each iteration.
+     And if you're using activation checkpointing it then occurs thrice.
+
+     Args:
+         fn: the function to call on sharded inputs (e.g., mlp.forward)
+         mlp_module: the MLP nn.Module object
+         x: the input to MLP.forward (hidden_states)
+         shards: how many shards to use
+         compute_params: a list of weights engaged in the compute
+
+     Returns:
+         the computed hidden_states
+     """
+
+     @staticmethod
+     @ensure_contiguous
+     def forward(
+         ctx,
+         fn: Callable,
+         mlp_module: torch.nn.Module,
+         x: torch.Tensor,
+         shards: int,
+         compute_params: Optional[List[torch.nn.Parameter]] = None,
+     ) -> torch.Tensor:
+         ctx.fn = fn
+         ctx.mlp_module = mlp_module
+         ctx.shards = shards
+         ctx.save_for_backward(x)
+
+         # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
+         x_shards = list(torch.chunk(x, chunks=shards, dim=-2))
+         with torch.no_grad():
+             output_shards = [fn(mlp_module, x_shard) for x_shard in x_shards]
+         output_unsharded = torch.cat(output_shards, dim=-2)
+
+         return output_unsharded
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, *grads) -> tuple:
+         fn = ctx.fn
+         (x,) = ctx.saved_tensors
+         mlp_module = ctx.mlp_module
+         shards = ctx.shards
+
+         x_requires_grad = x.requires_grad
+         x = x.detach()
+         # detach() unsets x.requires_grad, so restore it
+         x.requires_grad_(x_requires_grad)
+
+         # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
+         hidden_size = x.shape[-1]
+         x_shape_orig = x.shape
+
+         # flatten bs+seqlen to avoid having stride issues when narrowing into seqlen w/ bs>1
+         x = x.view(-1, hidden_size)
+         incoming_grad = grads[0].view(-1, hidden_size)
+         x_grad = torch.zeros_like(x)
+
+         x_shards = list(torch.chunk(x, chunks=shards, dim=0))
+
+         for i, x_shard in enumerate(x_shards):
+             x_shard.requires_grad_(x_requires_grad)
+
+             # if seqlen is not exactly divisible by shards the last step will be shorter than shard_step
+             shard_step = x_shards[i].shape[0]
+             shard_offset = i * x_shards[0].shape[0]
+
+             x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
+             incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
+
+             with torch.enable_grad():
+                 output = fn(mlp_module, x_shard)
+             torch.autograd.backward(output, incoming_grad_shard)
+
+         # unflatten
+         x_grad = x_grad.view(x_shape_orig)
+
+         return (None, None, x_grad, None, None)
+
+
+ def apply_tiled_mlp(
+     fn: Callable,
+     mlp_module: torch.nn.Module,
+     x: torch.Tensor,
+     num_shards: Optional[int] = None,
+     compute_params: Optional[List[torch.nn.Parameter]] = None,
+ ) -> torch.Tensor:
+     """
+     Apply tiled MLP computation for memory efficiency.
+
+     Args:
+         fn: the function to call on sharded inputs (e.g., lambda module, x: module(x))
+         mlp_module: the MLP nn.Module object
+         x: the input tensor with shape [bs, seqlen, hidden_size] or [seqlen, hidden_size]
+         num_shards: number of shards to use. If None, automatically calculated as ceil(seqlen / hidden_size)
+         compute_params: list of parameters for DeepSpeed ZeRO optimization
+
+     Returns:
+         output tensor with the same shape as input
+     """
+     if num_shards is None:
+         # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size]
+         hidden_size = x.shape[-1]
+         seqlen = x.shape[-2]
+         num_shards = math.ceil(seqlen / hidden_size)
+
+     # Ensure num_shards is at least 1
+     num_shards = max(1, num_shards)
+
+     return LigerTiledMLPFunction.apply(
+         fn,
+         mlp_module,
+         x,
+         num_shards,
+         compute_params,
+     )
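
A short usage sketch of `apply_tiled_mlp` with a toy MLP; the `ToyMLP` module and the shapes below are illustrative, not part of the package.

import torch
import torch.nn as nn
import torch.nn.functional as F

from liger_kernel.ops.tiled_mlp import apply_tiled_mlp

class ToyMLP(nn.Module):
    # Stand-in for a transformer MLP block (hypothetical, for illustration only).
    def __init__(self, hidden_size=256, intermediate_size=1024):
        super().__init__()
        self.up = nn.Linear(hidden_size, intermediate_size)
        self.down = nn.Linear(intermediate_size, hidden_size)

    def forward(self, x):
        return self.down(F.silu(self.up(x)))

mlp = ToyMLP()
x = torch.randn(2, 4096, 256, requires_grad=True)  # [bs, seqlen, hidden_size]

# fn receives (module, shard); each shard reuses the same module weights, so peak
# activation memory scales with seqlen / num_shards instead of the full seqlen.
out = apply_tiled_mlp(lambda module, x_shard: module(x_shard), mlp, x, num_shards=8)
out.sum().backward()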
liger_kernel/ops/utils.py CHANGED
@@ -78,6 +78,8 @@ def get_amp_custom_fwd_bwd() -> Callable:
              functools.partial(torch.amp.custom_fwd, device_type=device),
              functools.partial(torch.amp.custom_bwd, device_type=device),
          )
+     if hasattr(torch, "npu") and getattr(torch.npu, "amp", None) is not None:
+         return torch.npu.amp.custom_fwd, torch.npu.amp.custom_bwd
      return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
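
For context, the pair returned by `get_amp_custom_fwd_bwd` is typically unpacked and used as decorators on an autograd Function so autocast state stays consistent between forward and backward across CUDA, XPU, and now NPU backends. A rough sketch follows; the `ScaleByTwo` op is hypothetical.

import torch

from liger_kernel.ops.utils import get_amp_custom_fwd_bwd

amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()

class ScaleByTwo(torch.autograd.Function):
    @staticmethod
    @amp_custom_fwd
    def forward(ctx, x):
        # custom_fwd/custom_bwd keep the autocast state consistent between
        # this forward and the matching backward call.
        return x * 2

    @staticmethod
    @amp_custom_bwd
    def backward(ctx, grad_output):
        return grad_output * 2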
 
liger_kernel/transformers/__init__.py CHANGED
@@ -15,6 +15,7 @@ from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401
  from liger_kernel.transformers.llama4_rope import liger_llama4_text_rotary_pos_emb # noqa: F401
  from liger_kernel.transformers.llama4_rope import liger_llama4_vision_rotary_pos_emb # noqa: F401
  from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention # noqa: F401
+ from liger_kernel.transformers.poly_norm import LigerPolyNorm # noqa: F401
  from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
  from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
  from liger_kernel.transformers.softmax import LigerSoftmax # noqa: F401
@@ -23,6 +24,8 @@ from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP # noqa: F4
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP # noqa: F401
  from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP # noqa: F401
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP # noqa: F401
+ from liger_kernel.transformers.tiled_mlp import LigerTiledGEGLUMLP # noqa: F401
+ from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP # noqa: F401
  from liger_kernel.transformers.tvd import LigerTVDLoss # noqa: F401

  # Static-only imports for IDEs and type checkers
@@ -39,6 +42,8 @@ if TYPE_CHECKING:
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_dense # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_moe # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_internvl # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4 # noqa: F401
@@ -47,6 +52,7 @@ if TYPE_CHECKING:
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2 # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo3 # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_paligemma # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3 # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2 # noqa: F401
@@ -54,7 +60,11 @@ if TYPE_CHECKING:
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3 # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_next # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl_moe # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3 # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smolvlm # noqa: F401


  # Check if 'transformers' is installed
@@ -109,6 +119,7 @@ def __getattr__(name: str):
          "apply_liger_kernel_to_mixtral",
          "apply_liger_kernel_to_mllama",
          "apply_liger_kernel_to_olmo2",
+         "apply_liger_kernel_to_olmo3",
          "apply_liger_kernel_to_paligemma",
          "apply_liger_kernel_to_phi3",
          "apply_liger_kernel_to_qwen2",
@@ -116,7 +127,13 @@ def __getattr__(name: str):
          "apply_liger_kernel_to_qwen2_vl",
          "apply_liger_kernel_to_qwen3",
          "apply_liger_kernel_to_qwen3_moe",
+         "apply_liger_kernel_to_qwen3_next",
+         "apply_liger_kernel_to_qwen3_vl",
+         "apply_liger_kernel_to_qwen3_vl_moe",
          "apply_liger_kernel_to_smollm3",
+         "apply_liger_kernel_to_smolvlm",
+         "apply_liger_kernel_to_hunyuan_v1_dense",
+         "apply_liger_kernel_to_hunyuan_v1_moe",
      }

      if name in monkey_patch_symbols:
@@ -137,6 +154,7 @@ __all__ = [
      "LigerJSD",
      "LigerLayerNorm",
      "LigerFusedAddRMSNorm",
+     "LigerPolyNorm",
      "LigerRMSNorm",
      "liger_rotary_pos_emb",
      "liger_llama4_text_rotary_pos_emb",
@@ -145,6 +163,8 @@ __all__ = [
      "LigerPhi3SwiGLUMLP",
      "LigerQwen3MoeSwiGLUMLP",
      "LigerSwiGLUMLP",
+     "LigerTiledGEGLUMLP",
+     "LigerTiledSwiGLUMLP",
      "LigerTVDLoss",
      "LigerKLDIVLoss",
      "LigerMultiTokenAttention",
@@ -176,6 +196,7 @@ if _TRANSFORMERS_AVAILABLE:
              "apply_liger_kernel_to_mixtral",
              "apply_liger_kernel_to_mllama",
              "apply_liger_kernel_to_olmo2",
+             "apply_liger_kernel_to_olmo3",
              "apply_liger_kernel_to_paligemma",
              "apply_liger_kernel_to_phi3",
              "apply_liger_kernel_to_qwen2",
@@ -183,6 +204,12 @@ if _TRANSFORMERS_AVAILABLE:
              "apply_liger_kernel_to_qwen2_vl",
              "apply_liger_kernel_to_qwen3",
              "apply_liger_kernel_to_qwen3_moe",
+             "apply_liger_kernel_to_qwen3_next",
+             "apply_liger_kernel_to_qwen3_vl",
+             "apply_liger_kernel_to_qwen3_vl_moe",
              "apply_liger_kernel_to_smollm3",
+             "apply_liger_kernel_to_smolvlm",
+             "apply_liger_kernel_to_hunyuan_v1_dense",
+             "apply_liger_kernel_to_hunyuan_v1_moe",
          ]
      )
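
Taken together, the new exports become importable from the top level. A brief sketch of the expanded surface follows; the no-argument call assumes the new apply_liger_kernel_to_* entry points follow the same keyword-defaulted convention as the existing ones, and the exact keyword arguments are defined in monkey_patch.py, which is not shown in this diff.

from liger_kernel.transformers import LigerPolyNorm, LigerTiledGEGLUMLP, LigerTiledSwiGLUMLP
from liger_kernel.transformers import (
    apply_liger_kernel_to_hunyuan_v1_dense,
    apply_liger_kernel_to_olmo3,
    apply_liger_kernel_to_qwen3_next,
    apply_liger_kernel_to_smolvlm,
)

# Patch before the corresponding Hugging Face model class is instantiated.
apply_liger_kernel_to_olmo3()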