liger-kernel-nightly 0.5.10.dev20250624183504__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package as published to a supported registry. It is provided for informational purposes only.

Potentially problematic release: this version of liger-kernel-nightly might be problematic (see the registry page for details).

Files changed (68)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +54 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +4 -0
  7. liger_kernel/chunked_loss/grpo_loss.py +38 -4
  8. liger_kernel/chunked_loss/jsd_loss.py +23 -7
  9. liger_kernel/ops/cross_entropy.py +118 -62
  10. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  11. liger_kernel/ops/fused_linear_cross_entropy.py +113 -21
  12. liger_kernel/ops/geglu.py +1 -1
  13. liger_kernel/ops/layer_norm.py +124 -89
  14. liger_kernel/ops/llama4_rope.py +225 -0
  15. liger_kernel/ops/poly_norm.py +386 -0
  16. liger_kernel/ops/rms_norm.py +2 -2
  17. liger_kernel/ops/rope.py +1 -1
  18. liger_kernel/ops/swiglu.py +1 -1
  19. liger_kernel/ops/tiled_mlp.py +136 -0
  20. liger_kernel/transformers/__init__.py +50 -0
  21. liger_kernel/transformers/cross_entropy.py +8 -3
  22. liger_kernel/transformers/experimental/__init__.py +5 -0
  23. liger_kernel/transformers/functional.py +38 -6
  24. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  25. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -4
  26. liger_kernel/transformers/llama4_rope.py +93 -0
  27. liger_kernel/transformers/model/falcon_h1.py +122 -0
  28. liger_kernel/transformers/model/gemma.py +28 -8
  29. liger_kernel/transformers/model/gemma2.py +31 -8
  30. liger_kernel/transformers/model/gemma3.py +100 -110
  31. liger_kernel/transformers/model/glm4.py +18 -5
  32. liger_kernel/transformers/model/glm4v.py +163 -0
  33. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  34. liger_kernel/transformers/model/internvl.py +157 -0
  35. liger_kernel/transformers/model/llama.py +26 -7
  36. liger_kernel/transformers/model/llama4.py +121 -0
  37. liger_kernel/transformers/model/llava.py +18 -6
  38. liger_kernel/transformers/model/loss_utils.py +34 -3
  39. liger_kernel/transformers/model/mistral.py +17 -10
  40. liger_kernel/transformers/model/mixtral.py +24 -9
  41. liger_kernel/transformers/model/mllama.py +18 -7
  42. liger_kernel/transformers/model/olmo2.py +18 -5
  43. liger_kernel/transformers/model/output_classes.py +147 -0
  44. liger_kernel/transformers/model/paligemma.py +41 -5
  45. liger_kernel/transformers/model/phi3.py +24 -159
  46. liger_kernel/transformers/model/qwen2.py +26 -4
  47. liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
  48. liger_kernel/transformers/model/qwen2_vl.py +24 -7
  49. liger_kernel/transformers/model/qwen3.py +22 -6
  50. liger_kernel/transformers/model/qwen3_moe.py +27 -7
  51. liger_kernel/transformers/model/qwen3_next.py +146 -0
  52. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  53. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  54. liger_kernel/transformers/model/smollm3.py +199 -0
  55. liger_kernel/transformers/model/smolvlm.py +158 -0
  56. liger_kernel/transformers/monkey_patch.py +1090 -116
  57. liger_kernel/transformers/multi_token_attention.py +1 -1
  58. liger_kernel/transformers/poly_norm.py +42 -0
  59. liger_kernel/transformers/rms_norm.py +7 -0
  60. liger_kernel/transformers/rope.py +43 -0
  61. liger_kernel/transformers/tiled_mlp.py +133 -0
  62. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +26 -24
  63. liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
  64. liger_kernel_nightly-0.5.10.dev20250624183504.dist-info/RECORD +0 -95
  65. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
  66. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
  67. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +0 -0
  68. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
liger_kernel/ops/poly_norm.py ADDED
@@ -0,0 +1,386 @@
+ import operator
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import ensure_contiguous
+
+ if compare_version("triton", operator.ge, "3.0.0"):
+     try:
+         from triton.language.extra.libdevice import rsqrt
+     except ModuleNotFoundError:
+         from triton.language.extra.cuda.libdevice import rsqrt
+ else:
+     from triton.language.math import rsqrt
+
+
+ @triton.jit
+ def _poly_norm_forward_kernel(
+     Y_ptr,
+     Y_row_stride,
+     X_ptr,
+     X_row_stride,
+     W_ptr,  # weight: [3] for [w0, w1, w2]
+     B_ptr,  # bias: scalar
+     RSTD_ptr,  # cache rstd for backward: shape (n_rows, 3)
+     RSTD_row_stride,
+     n_cols,
+     eps,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     """
+     PolyNorm formula:
+         y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+     where norm(u) = u / sqrt(mean(u²) + ε)
+
+     Reference:
+     1. https://github.com/BryceZhuo/PolyCom/
+     2. https://arxiv.org/pdf/2411.03884
+
+     Cache rstd values for backward pass
+     """
+     row_idx = tl.program_id(0).to(tl.int64)
+     col_offsets = tl.arange(0, BLOCK_SIZE)
+     mask = col_offsets < n_cols
+
+     # Load pointers
+     Y_ptr += row_idx * Y_row_stride
+     X_ptr += row_idx * X_row_stride
+     RSTD_ptr += row_idx * RSTD_row_stride
+
+     # Load input row
+     X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
+
+     # Load weights and bias
+     w0 = tl.load(W_ptr + 0)
+     w1 = tl.load(W_ptr + 1)
+     w2 = tl.load(W_ptr + 2)
+     b = tl.load(B_ptr)
+
+     # Compute x³, x², x
+     X_pow3 = X_row * X_row * X_row
+     X_pow2 = X_row * X_row
+     X_pow1 = X_row
+
+     # Compute norm(x³): norm(u) = u * rsqrt(mean(u²) + eps)
+     mean_square_3 = tl.sum(X_pow3 * X_pow3, axis=0) / n_cols
+     rstd_3 = rsqrt(mean_square_3 + eps)
+     norm_x3 = X_pow3 * rstd_3
+
+     # Compute norm(x²)
+     mean_square_2 = tl.sum(X_pow2 * X_pow2, axis=0) / n_cols
+     rstd_2 = rsqrt(mean_square_2 + eps)
+     norm_x2 = X_pow2 * rstd_2
+
+     # Compute norm(x)
+     mean_square_1 = tl.sum(X_pow1 * X_pow1, axis=0) / n_cols
+     rstd_1 = rsqrt(mean_square_1 + eps)
+     norm_x1 = X_pow1 * rstd_1
+
+     # Cache rstd values for backward
+     tl.store(RSTD_ptr + 0, rstd_3)
+     tl.store(RSTD_ptr + 1, rstd_2)
+     tl.store(RSTD_ptr + 2, rstd_1)
+
+     # Compute output: y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+     Y_row = w0 * norm_x3 + w1 * norm_x2 + w2 * norm_x1 + b
+
+     # Store output
+     tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
+
+
+ @triton.jit
+ def _poly_norm_backward_kernel(
+     dY_ptr,
+     dY_row_stride,
+     dX_ptr,
+     dX_row_stride,
+     X_ptr,
+     X_row_stride,
+     W_ptr,
+     RSTD_ptr,
+     RSTD_row_stride,
+     dW_ptr,  # shape: (n_programs, 3)
+     dW_row_stride,
+     dB_ptr,  # shape: (n_programs,)
+     n_rows,
+     n_cols,
+     rows_per_program: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     """
+     PolyNorm Backward Kernel Gradient:
+         ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
+
+     where:
+         - D_p = RMS(x^p) = 1/rstd_p
+         - S_p = sum(grad * x^p) over the row
+         - d = n_cols
+         - p ∈ {3, 2, 1}
+     """
+     row_block_id = tl.program_id(0).to(tl.int64)
+     row_start = row_block_id * rows_per_program
+     row_end = min((row_block_id + 1) * rows_per_program, n_rows)
+     col_offsets = tl.arange(0, BLOCK_SIZE)
+     mask = col_offsets < n_cols
+
+     # Initialize accumulators for weight and bias gradients (scalars)
+     dW0_acc = 0.0
+     dW1_acc = 0.0
+     dW2_acc = 0.0
+     dB_acc = 0.0
+
+     # Load weights
+     w0 = tl.load(W_ptr + 0).to(tl.float32)
+     w1 = tl.load(W_ptr + 1).to(tl.float32)
+     w2 = tl.load(W_ptr + 2).to(tl.float32)
+
+     dY_ptr += row_start * dY_row_stride
+     dX_ptr += row_start * dX_row_stride
+     X_ptr += row_start * X_row_stride
+     RSTD_ptr += row_start * RSTD_row_stride
+
+     for _ in range(row_start, row_end):
+         # Load input and gradient
+         dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
+         X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
+
+         # Load cached rstd values
+         rstd_3 = tl.load(RSTD_ptr + 0).to(tl.float32)
+         rstd_2 = tl.load(RSTD_ptr + 1).to(tl.float32)
+         rstd_1 = tl.load(RSTD_ptr + 2).to(tl.float32)
+
+         # Compute powers
+         X_pow3 = X_row * X_row * X_row
+         X_pow2 = X_row * X_row
+         X_pow1 = X_row
+
+         # Accumulate bias gradient: dB = sum(dY)
+         dB_acc += tl.sum(dY_row, axis=0)
+
+         # Compute gradient w.r.t. input using closed-form formula
+         # For p=3: ∂L/∂x from w0 * norm(x³)
+         S_3 = tl.sum(dY_row * X_pow3, axis=0)  # scalar
+         grad_x_3 = w0 * (
+             3.0 * X_pow2 * rstd_3 * dY_row
+             - (3.0 / n_cols) * X_row * X_row * X_row * X_row * X_row * (rstd_3 * rstd_3 * rstd_3) * S_3
+         )
+
+         # For p=2: ∂L/∂x from w1 * norm(x²)
+         S_2 = tl.sum(dY_row * X_pow2, axis=0)  # scalar
+         grad_x_2 = w1 * (
+             2.0 * X_row * rstd_2 * dY_row - (2.0 / n_cols) * X_row * X_row * X_row * (rstd_2 * rstd_2 * rstd_2) * S_2
+         )
+
+         # For p=1: ∂L/∂x from w2 * norm(x)
+         S_1 = tl.sum(dY_row * X_pow1, axis=0)  # scalar
+         grad_x_1 = w2 * (1.0 * rstd_1 * dY_row - (1.0 / n_cols) * X_row * (rstd_1 * rstd_1 * rstd_1) * S_1)
+
+         # Accumulate weight gradients using closed-form: dW_p = rstd_p * S_p
+         dW0_acc += rstd_3 * S_3
+         dW1_acc += rstd_2 * S_2
+         dW2_acc += rstd_1 * S_1
+
+         # Total gradient
+         dX_row = grad_x_3 + grad_x_2 + grad_x_1
+
+         # Store gradient
+         tl.store(dX_ptr + col_offsets, dX_row, mask=mask)
+
+         # Update pointers
+         dY_ptr += dY_row_stride
+         dX_ptr += dX_row_stride
+         X_ptr += X_row_stride
+         RSTD_ptr += RSTD_row_stride
+
+     # Store accumulated gradients (scalars)
+     tl.store(dW_ptr + row_block_id * dW_row_stride + 0, dW0_acc)
+     tl.store(dW_ptr + row_block_id * dW_row_stride + 1, dW1_acc)
+     tl.store(dW_ptr + row_block_id * dW_row_stride + 2, dW2_acc)
+     tl.store(dB_ptr + row_block_id, dB_acc)
+
+
+ def poly_norm_forward(X, W, B, eps=1e-6):
+     """
+     PolyNorm Forward Pass
+
+     Args:
+         X: input tensor of shape (*, H) where H is hidden dimension
+         W: weight tensor of shape (3,) for [w0, w1, w2]
+         B: bias scalar tensor
+         eps: epsilon for numerical stability
+
+     Returns:
+         Y: output tensor of same shape as X
+         X: reshaped input (for backward)
+         RSTD: cached rstd values (for backward)
+         BLOCK_SIZE: block size used
+         num_warps: number of warps used
+     """
+     shape = X.shape
+     dim = shape[-1]
+     X = X.view(-1, dim)
+     n_rows, n_cols = X.shape
+     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+     # RSTD is to cache rstd for each row
+     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+     RSTD = torch.empty((n_rows, 3), dtype=torch.float32, device=X.device)
+
+     # Check constraints
+     assert W.shape[0] == 3, "Weight tensor must have shape (3,)"
+     assert B.numel() == 1, "Bias must be a scalar"
+
+     # XPU-specific optimization
+     kernel_args = {}
+     if X.device.type == "xpu":
+         kernel_args["grf_mode"] = "large"
+
+     # Launch kernel
+     _poly_norm_forward_kernel[(n_rows,)](
+         Y,
+         Y.stride(0),
+         X,
+         X.stride(0),
+         W,
+         B,
+         RSTD,
+         RSTD.stride(0),
+         n_cols,
+         eps,
+         BLOCK_SIZE=BLOCK_SIZE,
+         num_warps=num_warps,
+         **kernel_args,
+     )
+
+     return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps
+
+
+ def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
+     """
+     PolyNorm Backward Pass
+
+     Args:
+         dY: gradient of output
+         X: input tensor (already reshaped to 2D)
+         W: weight tensor
+         RSTD: cached rstd values from forward
+         BLOCK_SIZE: block size from forward
+         num_warps: number of warps from forward
+         in_place: whether to in-place modify dY to store dX (saves memory)
+
+     Returns:
+         dX: gradient w.r.t. input
+         dW: gradient w.r.t. weight
+         dB: gradient w.r.t. bias
+     """
+     shape = dY.shape
+     dim = shape[-1]
+     dY = dY.view(-1, dim)
+     n_rows, n_cols = dY.shape
+
+     # Get number of SMs for parallelization
+     import math
+
+     sm_count = 1
+     if X.device.type == "cuda":
+         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+     elif X.device.type == "xpu":
+         sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+
+     # Allocate or reuse gradients
+     if in_place is True:
+         dX = dY
+     else:
+         dX = torch.zeros_like(dY)
+
+     _dW = torch.empty((sm_count, 3), dtype=torch.float32, device=W.device)
+     _dB = torch.empty((sm_count,), dtype=torch.float32, device=W.device)
+
+     rows_per_program = math.ceil(n_rows / sm_count)
+     grid = (sm_count,)
+
+     # XPU-specific optimization
+     kernel_args = {}
+     if X.device.type == "xpu":
+         kernel_args["grf_mode"] = "large"
+
+     # Launch backward kernel
+     _poly_norm_backward_kernel[grid](
+         dY,
+         dY.stride(0),
+         dX,
+         dX.stride(0),
+         X,
+         X.stride(0),
+         W,
+         RSTD,
+         RSTD.stride(0),
+         _dW,
+         _dW.stride(0),
+         _dB,
+         n_rows,
+         n_cols,
+         rows_per_program,
+         BLOCK_SIZE=BLOCK_SIZE,
+         num_warps=num_warps,
+         **kernel_args,
+     )
+
+     # Reduce gradients across SMs
+     dX = dX.view(*shape)
+     dW = _dW.sum(dim=0).to(W.dtype)
+     dB = _dB.sum().to(W.dtype)
+
+     return dX, dW, dB
+
+
+ class LigerPolyNormFunction(torch.autograd.Function):
+     """
+     PolyNorm Function with forward and backward pass
+
+     PolyNorm formula:
+         y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+     where norm(u) = u / sqrt(mean(u²) + ε)
+
+     Backward uses closed-form gradient:
+         ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
+     """
+
+     @staticmethod
+     @ensure_contiguous
+     def forward(ctx, X, W, B, eps=1e-6, in_place=True):
+         """
+         Args:
+             X: input tensor of shape (B, T, H) or (BxT, H)
+             W: weight tensor of shape (3,) for [w0, w1, w2]
+             B: bias scalar
+             eps: epsilon for numerical stability
+             in_place: whether to in-place modify grad_output in backward (saves memory)
+
+         Returns:
+             Y: output tensor of same shape as X
+         """
+         Y, X, RSTD, BLOCK_SIZE, num_warps = poly_norm_forward(X, W, B, eps)
+         ctx.BLOCK_SIZE = BLOCK_SIZE
+         ctx.num_warps = num_warps
+         ctx.in_place = in_place
+         ctx.save_for_backward(X, W, RSTD)
+         return Y
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, grad_output):
+         """
+         Args:
+             grad_output: gradient of output
+
+         Returns:
+             dX, dW, dB: gradients w.r.t. X, W, B
+         """
+         X, W, RSTD = ctx.saved_tensors
+         dX, dW, dB = poly_norm_backward(grad_output, X, W, RSTD, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place)
+         return dX, dW, dB, None, None
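
For context, the PolyNorm math implemented by the kernels above can be sanity-checked against a plain PyTorch version of the same formula. The sketch below is illustrative and not part of the diff; the commented comparison assumes the LigerPolyNormFunction API exactly as added above and a CUDA/XPU device:

import torch


def poly_norm_reference(x, w, b, eps=1e-6):
    # norm(u) = u / sqrt(mean(u**2) + eps), mean taken over the hidden (last) dimension
    def _norm(u):
        return u * torch.rsqrt(u.pow(2).mean(dim=-1, keepdim=True) + eps)

    # y = w0 * norm(x**3) + w1 * norm(x**2) + w2 * norm(x) + b
    return w[0] * _norm(x**3) + w[1] * _norm(x**2) + w[2] * _norm(x) + b


# Hypothetical comparison against the Triton path (illustrative shapes, GPU required):
# from liger_kernel.ops.poly_norm import LigerPolyNormFunction
# x = torch.randn(2, 128, 4096, device="cuda", requires_grad=True)
# w = torch.randn(3, device="cuda")
# b = torch.randn((), device="cuda")
# torch.testing.assert_close(
#     LigerPolyNormFunction.apply(x, w, b), poly_norm_reference(x, w, b), rtol=1e-2, atol=1e-2
# )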
liger_kernel/ops/rms_norm.py CHANGED
@@ -63,7 +63,7 @@ def _rms_norm_forward_kernel(
      3. https://arxiv.org/pdf/1910.07467
      """

-     row_idx = tl.program_id(0)
+     row_idx = tl.program_id(0).to(tl.int64)
      col_offsets = tl.arange(0, BLOCK_SIZE)
      mask = col_offsets < n_cols

@@ -137,7 +137,7 @@ def _rms_norm_backward_kernel(
      dw = sum(dy * (x / RMS)). summation over BxT dimension
      """

-     row_block_id = tl.program_id(0)
+     row_block_id = tl.program_id(0).to(tl.int64)
      row_start = row_block_id * rows_per_program
      row_end = min((row_block_id + 1) * rows_per_program, n_rows)
      col_offsets = tl.arange(0, BLOCK_SIZE)
liger_kernel/ops/rope.py CHANGED
@@ -32,7 +32,7 @@ def _triton_rope(

      # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
      # stride: (seq_len * head_dim, head_dim, 1)
-     pid = tl.program_id(0)
+     pid = tl.program_id(0).to(tl.int64)

      # locate start address
      q_ptr = q_ptr + pid * q_row_stride
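
The `.to(tl.int64)` casts in the rms_norm.py and rope.py hunks above widen the program id before it is multiplied by a row stride; this appears to guard against 32-bit overflow of the resulting element offset on very large inputs. A rough check in plain Python (illustrative numbers only):

n_rows, hidden = 600_000, 4096        # e.g. bs * seqlen rows of a 4096-wide activation
print(n_rows * hidden)                # 2_457_600_000 elements
print(n_rows * hidden > 2**31 - 1)    # True: a 32-bit row offset would wrap around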
liger_kernel/ops/swiglu.py CHANGED
@@ -26,7 +26,7 @@ def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BL
      # sigmoid requires type float32
      a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
      b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
-     c_row = silu(a_row) * b_row
+     c_row = silu(a_row).cast(b_row.dtype) * b_row
      tl.store(c_ptr + col_offsets, c_row, mask=mask)


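The swiglu.py change keeps the SiLU evaluation in float32 but casts it back to b_row's dtype before the multiply, so the stored product stays in the input precision. The eager-PyTorch sketch below only mirrors that numeric intent and is not the kernel itself:

import torch

a = torch.randn(8, dtype=torch.bfloat16)
b = torch.randn(8, dtype=torch.bfloat16)

silu_fp32 = torch.nn.functional.silu(a.float())  # activation computed in float32, as in the kernel
c = silu_fp32.to(b.dtype) * b                    # cast back to b's dtype before the product
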
liger_kernel/ops/tiled_mlp.py ADDED
@@ -0,0 +1,136 @@
+ import math
+
+ from typing import Callable
+ from typing import List
+ from typing import Optional
+
+ import torch
+
+ from liger_kernel.ops.utils import ensure_contiguous
+
+
+ class LigerTiledMLPFunction(torch.autograd.Function):
+     """
+     Based on DeepSpeed's TiledMLP:
+     https://github.com/deepspeedai/DeepSpeed/blob/v0.18.2/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L838
+
+     Perform a tiled MLP computation to massively reduce memory usage needed to compute MLP
+     when using very long sequence lengths.
+
+     This module re-computes `forward` in the `backward`. So the `forward` occurs twice each iteration.
+     And if you're using activation checkpointing it then occurs thrice.
+
+     Args:
+         fn: the function to call on sharded inputs (e.g., mlp.forward)
+         mlp_module: the MLP nn.Module object
+         x: the input to MLP.forward (hidden_states)
+         shards: how many shards to use
+         compute_params: a list of weights engaged in the compute
+
+     Returns:
+         the computed hidden_states
+     """
+
+     @staticmethod
+     @ensure_contiguous
+     def forward(
+         ctx,
+         fn: Callable,
+         mlp_module: torch.nn.Module,
+         x: torch.Tensor,
+         shards: int,
+         compute_params: Optional[List[torch.nn.Parameter]] = None,
+     ) -> torch.Tensor:
+         ctx.fn = fn
+         ctx.mlp_module = mlp_module
+         ctx.shards = shards
+         ctx.save_for_backward(x)
+
+         # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
+         x_shards = list(torch.chunk(x, chunks=shards, dim=-2))
+         with torch.no_grad():
+             output_shards = [fn(mlp_module, x_shard) for x_shard in x_shards]
+         output_unsharded = torch.cat(output_shards, dim=-2)
+
+         return output_unsharded
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, *grads) -> tuple:
+         fn = ctx.fn
+         (x,) = ctx.saved_tensors
+         mlp_module = ctx.mlp_module
+         shards = ctx.shards
+
+         x_requires_grad = x.requires_grad
+         x = x.detach()
+         # detach() unsets x.requires_grad, so restore it
+         x.requires_grad_(x_requires_grad)
+
+         # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
+         hidden_size = x.shape[-1]
+         x_shape_orig = x.shape
+
+         # flatten bs+seqlen to avoid having stride issues when narrowing into seqlen w/ bs>1
+         x = x.view(-1, hidden_size)
+         incoming_grad = grads[0].view(-1, hidden_size)
+         x_grad = torch.zeros_like(x)
+
+         x_shards = list(torch.chunk(x, chunks=shards, dim=0))
+
+         for i, x_shard in enumerate(x_shards):
+             x_shard.requires_grad_(x_requires_grad)
+
+             # if seqlen is not exactly divisible by shards the last step will be shorter than shard_step
+             shard_step = x_shards[i].shape[0]
+             shard_offset = i * x_shards[0].shape[0]
+
+             x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
+             incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
+
+             with torch.enable_grad():
+                 output = fn(mlp_module, x_shard)
+             torch.autograd.backward(output, incoming_grad_shard)
+
+         # unflatten
+         x_grad = x_grad.view(x_shape_orig)
+
+         return (None, None, x_grad, None, None)
+
+
+ def apply_tiled_mlp(
+     fn: Callable,
+     mlp_module: torch.nn.Module,
+     x: torch.Tensor,
+     num_shards: Optional[int] = None,
+     compute_params: Optional[List[torch.nn.Parameter]] = None,
+ ) -> torch.Tensor:
+     """
+     Apply tiled MLP computation for memory efficiency.
+
+     Args:
+         fn: the function to call on sharded inputs (e.g., lambda module, x: module(x))
+         mlp_module: the MLP nn.Module object
+         x: the input tensor with shape [bs, seqlen, hidden_size] or [seqlen, hidden_size]
+         num_shards: number of shards to use. If None, automatically calculated as ceil(seqlen / hidden_size)
+         compute_params: list of parameters for DeepSpeed ZeRO optimization
+
+     Returns:
+         output tensor with the same shape as input
+     """
+     if num_shards is None:
+         # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size]
+         hidden_size = x.shape[-1]
+         seqlen = x.shape[-2]
+         num_shards = math.ceil(seqlen / hidden_size)
+
+     # Ensure num_shards is at least 1
+     num_shards = max(1, num_shards)
+
+     return LigerTiledMLPFunction.apply(
+         fn,
+         mlp_module,
+         x,
+         num_shards,
+         compute_params,
+     )
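
A minimal usage sketch for the new apply_tiled_mlp helper, relying only on the signature shown in the diff above; the toy MLP, shapes, and shard count are illustrative and not part of liger-kernel:

import torch

from liger_kernel.ops.tiled_mlp import apply_tiled_mlp

# illustrative stand-in for a transformer MLP block
mlp = torch.nn.Sequential(
    torch.nn.Linear(256, 1024),
    torch.nn.SiLU(),
    torch.nn.Linear(1024, 256),
)
x = torch.randn(2, 4096, 256, requires_grad=True)  # [bs, seqlen, hidden_size]

# fn receives (module, shard) and returns that shard's output; the forward is re-run
# shard by shard in backward, trading extra compute for lower activation memory.
out = apply_tiled_mlp(lambda module, t: module(t), mlp, x, num_shards=8)
out.sum().backward()
print(x.grad.shape)  # torch.Size([2, 4096, 256])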
liger_kernel/transformers/__init__.py CHANGED
@@ -5,17 +5,27 @@ from typing import TYPE_CHECKING
  # Always-safe imports (independent of 'transformers')
  from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss # noqa: F401
  from liger_kernel.transformers.dyt import LigerDyT # noqa: F401
+ from liger_kernel.transformers.fused_add_rms_norm import LigerFusedAddRMSNorm # noqa: F401
  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss # noqa: F401
  from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401
  from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401
  from liger_kernel.transformers.jsd import LigerJSD # noqa: F401
+ from liger_kernel.transformers.kl_div import LigerKLDIVLoss # noqa: F401
  from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401
+ from liger_kernel.transformers.llama4_rope import liger_llama4_text_rotary_pos_emb # noqa: F401
+ from liger_kernel.transformers.llama4_rope import liger_llama4_vision_rotary_pos_emb # noqa: F401
+ from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention # noqa: F401
+ from liger_kernel.transformers.poly_norm import LigerPolyNorm # noqa: F401
  from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
  from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
+ from liger_kernel.transformers.softmax import LigerSoftmax # noqa: F401
+ from liger_kernel.transformers.sparsemax import LigerSparsemax # noqa: F401
  from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP # noqa: F401
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP # noqa: F401
  from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP # noqa: F401
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP # noqa: F401
+ from liger_kernel.transformers.tiled_mlp import LigerTiledGEGLUMLP # noqa: F401
+ from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP # noqa: F401
  from liger_kernel.transformers.tvd import LigerTVDLoss # noqa: F401

  # Static-only imports for IDEs and type checkers
@@ -23,13 +33,18 @@ if TYPE_CHECKING:
      from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM # noqa: F401
      from liger_kernel.transformers.monkey_patch import _apply_liger_kernel # noqa: F401
      from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_falcon_h1 # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3 # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_internvl # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4 # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401
@@ -42,6 +57,11 @@ if TYPE_CHECKING:
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3 # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_next # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl_moe # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3 # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smolvlm # noqa: F401


  # Check if 'transformers' is installed
@@ -79,14 +99,19 @@ def __getattr__(name: str):
      monkey_patch_symbols = {
          "_apply_liger_kernel",
          "_apply_liger_kernel_to_instance",
+         "apply_liger_kernel_to_falcon_h1",
          "apply_liger_kernel_to_gemma",
          "apply_liger_kernel_to_gemma2",
          "apply_liger_kernel_to_gemma3",
          "apply_liger_kernel_to_gemma3_text",
          "apply_liger_kernel_to_glm4",
+         "apply_liger_kernel_to_glm4v",
+         "apply_liger_kernel_to_glm4v_moe",
          "apply_liger_kernel_to_granite",
+         "apply_liger_kernel_to_internvl",
          "apply_liger_kernel_to_llama",
          "apply_liger_kernel_to_llava",
+         "apply_liger_kernel_to_llama4",
          "apply_liger_kernel_to_mistral",
          "apply_liger_kernel_to_mixtral",
          "apply_liger_kernel_to_mllama",
@@ -98,6 +123,11 @@ def __getattr__(name: str):
          "apply_liger_kernel_to_qwen2_vl",
          "apply_liger_kernel_to_qwen3",
          "apply_liger_kernel_to_qwen3_moe",
+         "apply_liger_kernel_to_qwen3_next",
+         "apply_liger_kernel_to_qwen3_vl",
+         "apply_liger_kernel_to_qwen3_vl_moe",
+         "apply_liger_kernel_to_smollm3",
+         "apply_liger_kernel_to_smolvlm",
      }

      if name in monkey_patch_symbols:
@@ -117,13 +147,23 @@ __all__ = [
      "LigerGEGLUMLP",
      "LigerJSD",
      "LigerLayerNorm",
+     "LigerFusedAddRMSNorm",
+     "LigerPolyNorm",
      "LigerRMSNorm",
      "liger_rotary_pos_emb",
+     "liger_llama4_text_rotary_pos_emb",
+     "liger_llama4_vision_rotary_pos_emb",
      "LigerBlockSparseTop2MLP",
      "LigerPhi3SwiGLUMLP",
      "LigerQwen3MoeSwiGLUMLP",
      "LigerSwiGLUMLP",
+     "LigerTiledGEGLUMLP",
+     "LigerTiledSwiGLUMLP",
      "LigerTVDLoss",
+     "LigerKLDIVLoss",
+     "LigerMultiTokenAttention",
+     "LigerSoftmax",
+     "LigerSparsemax",
  ]

  # Add transformer-dependent symbols only if available
@@ -133,14 +173,19 @@ if _TRANSFORMERS_AVAILABLE:
              "AutoLigerKernelForCausalLM",
              "_apply_liger_kernel",
              "_apply_liger_kernel_to_instance",
+             "apply_liger_kernel_to_falcon_h1",
              "apply_liger_kernel_to_gemma",
              "apply_liger_kernel_to_gemma2",
              "apply_liger_kernel_to_gemma3",
              "apply_liger_kernel_to_gemma3_text",
              "apply_liger_kernel_to_glm4",
+             "apply_liger_kernel_to_glm4v",
+             "apply_liger_kernel_to_glm4v_moe",
              "apply_liger_kernel_to_granite",
+             "apply_liger_kernel_to_internvl",
              "apply_liger_kernel_to_llama",
              "apply_liger_kernel_to_llava",
+             "apply_liger_kernel_to_llama4",
              "apply_liger_kernel_to_mistral",
              "apply_liger_kernel_to_mixtral",
              "apply_liger_kernel_to_mllama",
@@ -152,5 +197,10 @@ if _TRANSFORMERS_AVAILABLE:
              "apply_liger_kernel_to_qwen2_vl",
              "apply_liger_kernel_to_qwen3",
              "apply_liger_kernel_to_qwen3_moe",
+             "apply_liger_kernel_to_qwen3_next",
+             "apply_liger_kernel_to_qwen3_vl",
+             "apply_liger_kernel_to_qwen3_vl_moe",
+             "apply_liger_kernel_to_smollm3",
+             "apply_liger_kernel_to_smolvlm",
          ]
      )
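
Taken together with the new modules above, the __init__.py hunk makes the added kernels and monkey-patch entry points importable from liger_kernel.transformers. A minimal usage sketch, assuming the new apply_liger_kernel_to_* helpers keep the zero-argument calling convention of the existing ones (the apply_* names are resolved lazily via __getattr__, so transformers must be installed):

from liger_kernel.transformers import LigerFusedAddRMSNorm  # noqa: F401
from liger_kernel.transformers import LigerPolyNorm  # noqa: F401
from liger_kernel.transformers import LigerTiledSwiGLUMLP  # noqa: F401
from liger_kernel.transformers import apply_liger_kernel_to_smollm3

# Patch the SmolLM3 modeling code in place before the model is instantiated,
# mirroring how the existing apply_liger_kernel_to_* helpers are used.
apply_liger_kernel_to_smollm3()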