liger-kernel 0.6.3__py3-none-any.whl → 0.6.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +13 -4
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +8 -5
  5. liger_kernel/chunked_loss/jsd_loss.py +18 -5
  6. liger_kernel/ops/cross_entropy.py +59 -9
  7. liger_kernel/ops/fused_linear_cross_entropy.py +30 -4
  8. liger_kernel/ops/grpo_loss.py +3 -1
  9. liger_kernel/ops/layer_norm.py +84 -65
  10. liger_kernel/ops/tiled_mlp.py +136 -0
  11. liger_kernel/transformers/__init__.py +19 -0
  12. liger_kernel/transformers/cross_entropy.py +8 -3
  13. liger_kernel/transformers/functional.py +24 -6
  14. liger_kernel/transformers/fused_linear_cross_entropy.py +8 -3
  15. liger_kernel/transformers/grpo_loss.py +56 -1
  16. liger_kernel/transformers/model/falcon_h1.py +19 -5
  17. liger_kernel/transformers/model/gemma.py +17 -6
  18. liger_kernel/transformers/model/gemma2.py +14 -5
  19. liger_kernel/transformers/model/gemma3.py +25 -12
  20. liger_kernel/transformers/model/glm4.py +16 -4
  21. liger_kernel/transformers/model/glm4v.py +16 -4
  22. liger_kernel/transformers/model/glm4v_moe.py +23 -4
  23. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  24. liger_kernel/transformers/model/internvl.py +12 -5
  25. liger_kernel/transformers/model/llama.py +14 -5
  26. liger_kernel/transformers/model/llama4.py +16 -4
  27. liger_kernel/transformers/model/llava.py +12 -4
  28. liger_kernel/transformers/model/loss_utils.py +31 -3
  29. liger_kernel/transformers/model/mistral.py +15 -6
  30. liger_kernel/transformers/model/mixtral.py +16 -7
  31. liger_kernel/transformers/model/mllama.py +12 -4
  32. liger_kernel/transformers/model/olmo2.py +16 -4
  33. liger_kernel/transformers/model/olmo3.py +142 -0
  34. liger_kernel/transformers/model/output_classes.py +147 -0
  35. liger_kernel/transformers/model/paligemma.py +22 -5
  36. liger_kernel/transformers/model/phi3.py +14 -7
  37. liger_kernel/transformers/model/qwen2.py +16 -3
  38. liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
  39. liger_kernel/transformers/model/qwen2_vl.py +16 -4
  40. liger_kernel/transformers/model/qwen3.py +20 -5
  41. liger_kernel/transformers/model/qwen3_moe.py +19 -5
  42. liger_kernel/transformers/model/qwen3_next.py +17 -5
  43. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  44. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  45. liger_kernel/transformers/model/smollm3.py +15 -6
  46. liger_kernel/transformers/monkey_patch.py +398 -20
  47. liger_kernel/transformers/rope.py +43 -0
  48. liger_kernel/transformers/swiglu.py +17 -0
  49. liger_kernel/transformers/tiled_mlp.py +133 -0
  50. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/METADATA +4 -1
  51. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/RECORD +55 -48
  52. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/WHEEL +0 -0
  53. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/licenses/LICENSE +0 -0
  54. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/licenses/NOTICE +0 -0
  55. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/top_level.txt +0 -0
@@ -27,8 +27,12 @@ def fused_linear_cross_entropy_forward(
27
27
  return_z_loss=False,
28
28
  accum_dtype=None,
29
29
  use_token_scaling=False,
30
+ return_token_accuracy=False,
30
31
  ):
31
32
  assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
33
+ assert isinstance(return_token_accuracy, bool), (
34
+ f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
35
+ )
32
36
  device = _input.device
33
37
 
34
38
  input_requires_grad = _input.requires_grad
@@ -58,9 +62,13 @@ def fused_linear_cross_entropy_forward(
58
62
  else:
59
63
  grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
60
64
  grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
65
+ else:
66
+ grad_weight = None
67
+ grad_bias = None
61
68
 
62
69
  loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
63
70
  z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
71
+ token_accuracy_1d = torch.zeros(BT, dtype=torch.float32, device=device) if return_token_accuracy else None
64
72
 
65
73
  # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
66
74
  target_mask = target != ignore_index
@@ -126,6 +134,7 @@ def fused_linear_cross_entropy_forward(
126
134
  # unreduced loss
127
135
  loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size,
128
136
  z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
137
+ token_accuracy_1d_slice = token_accuracy_1d[start_idx:end_idx] if return_token_accuracy else None
129
138
 
130
139
  # ensure _input and target are contiguous
131
140
  logits_chunk = logits_chunk.contiguous()
@@ -141,6 +150,10 @@ def fused_linear_cross_entropy_forward(
141
150
  loss_ptr=loss_1d_slice,
142
151
  z_loss_ptr=z_loss_1d_slice,
143
152
  loss_stride=loss_1d_slice.stride(-1), # always 1
153
+ token_accuracy_ptr=token_accuracy_1d_slice,
154
+ token_accuracy_stride=token_accuracy_1d_slice.stride(-1)
155
+ if return_token_accuracy
156
+ else 0, # always 1 if accuracy is enabled
144
157
  n_cols=V,
145
158
  n_non_ignore=total_n_non_ignore,
146
159
  sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
@@ -151,6 +164,7 @@ def fused_linear_cross_entropy_forward(
151
164
  reduction=reduction,
152
165
  softcap=softcap,
153
166
  RETURN_Z_LOSS=return_z_loss,
167
+ RETURN_TOKEN_ACCURACY=return_token_accuracy,
154
168
  HAS_WEIGHT=True if ce_weight is not None else False,
155
169
  HAS_SOFTCAPPING=True if softcap is not None else False,
156
170
  HAS_GRADIENTS=input_requires_grad,
@@ -167,6 +181,8 @@ def fused_linear_cross_entropy_forward(
167
181
  loss_1d[start_idx:end_idx] = loss_1d_slice
168
182
  if return_z_loss:
169
183
  z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
184
+ if return_token_accuracy:
185
+ token_accuracy_1d[start_idx:end_idx] = token_accuracy_1d_slice
170
186
  grad_logits_chunk = logits_chunk # chunk_size x V
171
187
 
172
188
  # Apply token scaling to gradients if requested
@@ -198,15 +214,18 @@ def fused_linear_cross_entropy_forward(
198
214
  # Return per-token losses
199
215
  loss = loss_1d
200
216
  z_loss = z_loss_1d if return_z_loss else None
217
+ token_accuracy = token_accuracy_1d if return_token_accuracy else None
201
218
  else:
202
219
  loss = torch.sum(loss_1d)
203
220
  z_loss = torch.sum(z_loss_1d) if return_z_loss else None
221
+ # For accuracy, we compute the mean across all non-ignored tokens
222
+ token_accuracy = torch.sum(token_accuracy_1d) / total_n_non_ignore if return_token_accuracy else None
204
223
 
205
224
  # Cast back to original dtype
206
225
  grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None
207
226
  grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None
208
227
 
209
- return loss, z_loss, grad_input, grad_weight, grad_bias
228
+ return loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias
210
229
 
211
230
 
212
231
  def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
@@ -274,6 +293,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
274
293
  return_z_loss: bool = False,
275
294
  accum_dtype=None,
276
295
  use_token_scaling: bool = False,
296
+ return_token_accuracy: bool = False,
277
297
  ):
278
298
  """
279
299
  Fusing the last linear layer with cross-entropy loss
@@ -297,9 +317,10 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
297
317
  use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
298
318
  When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
299
319
  Default: False.
320
+ return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
300
321
  """
301
322
 
302
- loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
323
+ loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
303
324
  _input=_input,
304
325
  weight=weight,
305
326
  target=target,
@@ -313,6 +334,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
313
334
  return_z_loss=return_z_loss,
314
335
  accum_dtype=accum_dtype,
315
336
  use_token_scaling=use_token_scaling,
337
+ return_token_accuracy=return_token_accuracy,
316
338
  )
317
339
  # downcast to dtype and store for backward
318
340
  ctx.save_for_backward(
@@ -321,13 +343,16 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
321
343
  grad_bias.detach() if bias is not None else None,
322
344
  )
323
345
  ctx.return_z_loss = return_z_loss
324
- return loss, z_loss
346
+ ctx.return_token_accuracy = return_token_accuracy
347
+ return loss, z_loss, token_accuracy
325
348
 
326
349
  @staticmethod
327
350
  @amp_custom_bwd
328
- def backward(ctx, grad_output, grad_output2):
351
+ def backward(ctx, grad_output, grad_output2, grad_output3):
329
352
  if ctx.return_z_loss:
330
353
  del grad_output2 # z_loss is only for logging
354
+ if ctx.return_token_accuracy:
355
+ del grad_output3 # token_accuracy is only for metrics
331
356
  (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
332
357
  grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
333
358
  grad_output, grad_input, grad_weight, grad_bias
@@ -346,4 +371,5 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
346
371
  None,
347
372
  None,
348
373
  None, # use_token_scaling
374
+ None, # return_token_accuracy
349
375
  )
@@ -128,7 +128,9 @@ def _grpo_loss_fwd_kernel(
128
128
  per_token_loss1 = coef_1 * advantage
129
129
  per_token_loss2 = coef_2 * advantage
130
130
  per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
131
- is_clipped = per_token_loss1 < per_token_loss2
131
+ is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0)
132
+ is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
133
+ is_clipped = is_low_clipped | is_high_clipped
132
134
 
133
135
  if BETA != 0.0:
134
136
  REF_LOGP += off_b * L + off_l
@@ -1,3 +1,4 @@
1
+ import math
1
2
  import operator
2
3
 
3
4
  import torch
@@ -85,68 +86,87 @@ def _layer_norm_forward_kernel(
85
86
  @triton.jit
86
87
  def _layer_norm_backward_kernel(
87
88
  X_ptr, # pointer to input, shape (n_rows, n_cols)
89
+ stride_x, # stride of each row in input
88
90
  W_ptr, # pointer to weights, shape (n_cols,)
89
91
  Mean_ptr, # pointer to mean, shape (n_rows,)
92
+ stride_mean, # stride of each row in mean
90
93
  RSTD_ptr, # pointer to rstd, shape (n_rows,)
94
+ stride_rstd, # stride of each row in rstd
91
95
  DX_ptr, # pointer to input grad, shape (n_rows, n_cols)
96
+ stride_dx, # stride of each row in input grad
92
97
  DW_ptr, # pointer to weights grad, shape (n_cols,)
98
+ stride_dw, # stride of each row in weights grad
93
99
  DB_ptr, # pointer to bias grad, shape (n_cols,)
100
+ stride_db, # stride of each row in bias grad
94
101
  DY_ptr, # pointer to output grad, shape (n_rows, n_cols)
95
- stride_x, # stride of each row in input
96
- stride_dx, # stride of each row in input grad
97
102
  stride_dy, # stride of each row in output grad
103
+ n_rows,
98
104
  n_cols,
105
+ rows_per_program: tl.constexpr,
99
106
  BLOCK_SIZE: tl.constexpr,
100
- dtype: tl.constexpr,
101
- atomic_dtype: tl.constexpr,
102
107
  ):
103
108
  """
104
109
  References:
105
110
  https://arxiv.org/abs/1607.06450
106
111
  https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
107
112
  """
108
- row_idx = tl.program_id(0).to(tl.int64)
113
+ row_block_id = tl.program_id(0).to(tl.int64)
114
+ row_start = row_block_id * rows_per_program
115
+ row_end = min((row_block_id + 1) * rows_per_program, n_rows)
109
116
  cols = tl.arange(0, BLOCK_SIZE)
110
117
  mask = cols < n_cols
111
118
 
119
+ dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
120
+ db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
121
+
112
122
  # Pre-load weights once (same optimization as forward pass)
113
123
  w = tl.load(W_ptr + cols, mask=mask, other=0.0)
114
124
  w_f32 = w.to(tl.float32)
115
125
 
116
126
  # Calculate pointers for this specific row
117
- row_X_ptr = X_ptr + row_idx * stride_x
118
- row_DX_ptr = DX_ptr + row_idx * stride_dx
119
- row_DY_ptr = DY_ptr + row_idx * stride_dy
120
- row_Mean_ptr = Mean_ptr + row_idx
121
- row_RSTD_ptr = RSTD_ptr + row_idx
122
-
123
- # Load data for this row
124
- x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
125
- dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
126
- mean = tl.load(row_Mean_ptr)
127
- rstd = tl.load(row_RSTD_ptr)
128
-
129
- # Convert to fp32 for numerical stability
130
- x_f32 = x.to(tl.float32)
131
- dy_f32 = dy.to(tl.float32)
132
- mean_f32 = mean.to(tl.float32)
133
- rstd_f32 = rstd.to(tl.float32)
134
-
135
- # Compute backward pass for this row
136
- x_hat = (x_f32 - mean_f32) * rstd_f32
137
- wdy = w_f32 * dy_f32
138
- c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
139
- c2 = tl.sum(wdy, axis=0) / n_cols
140
- dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
141
-
142
- # Store input gradient
143
- tl.store(row_DX_ptr + cols, dx.to(dtype), mask=mask)
144
-
145
- # Accumulate weight and bias gradients using atomic operations
146
- dw = dy_f32 * x_hat
147
- db = dy_f32
148
- tl.atomic_add(DW_ptr + cols, dw.to(atomic_dtype), mask=mask)
149
- tl.atomic_add(DB_ptr + cols, db.to(atomic_dtype), mask=mask)
127
+ row_X_ptr = X_ptr + row_start * stride_x
128
+ row_DX_ptr = DX_ptr + row_start * stride_dx
129
+ row_DY_ptr = DY_ptr + row_start * stride_dy
130
+ row_Mean_ptr = Mean_ptr + row_start
131
+ row_RSTD_ptr = RSTD_ptr + row_start
132
+
133
+ for _ in range(row_start, row_end):
134
+ # Load data for this row
135
+ x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
136
+ dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
137
+ mean = tl.load(row_Mean_ptr)
138
+ rstd = tl.load(row_RSTD_ptr)
139
+
140
+ # Convert to fp32 for numerical stability
141
+ x_f32 = x.to(tl.float32)
142
+ dy_f32 = dy.to(tl.float32)
143
+ mean_f32 = mean.to(tl.float32)
144
+ rstd_f32 = rstd.to(tl.float32)
145
+
146
+ # Compute backward pass for this row
147
+ x_hat = (x_f32 - mean_f32) * rstd_f32
148
+ wdy = w_f32 * dy_f32
149
+ c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
150
+ c2 = tl.sum(wdy, axis=0) / n_cols
151
+ dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
152
+
153
+ # Store input gradient
154
+ tl.store(row_DX_ptr + cols, dx, mask=mask)
155
+
156
+ # Accumulate weight and bias gradients for this thread block's assigned rows
157
+ dw = dy_f32 * x_hat
158
+ db = dy_f32
159
+ dW_row += dw
160
+ db_row += db
161
+
162
+ row_X_ptr += stride_x
163
+ row_DX_ptr += stride_dx
164
+ row_DY_ptr += stride_dy
165
+ row_Mean_ptr += stride_mean
166
+ row_RSTD_ptr += stride_rstd
167
+
168
+ tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
169
+ tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)
150
170
 
151
171
 
152
172
  def layer_norm_forward(X, W, B, eps):
@@ -228,31 +248,25 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
228
248
  dY = dY.view(-1, dim)
229
249
  n_rows, n_cols = dY.shape
230
250
 
231
- # Allocate gradient tensors
232
- DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
233
- # Use float32 for weight/bias gradients if bfloat16 (due to atomic_add limitation)
234
- grad_dtype = torch.float32 if W.dtype == torch.bfloat16 else W.dtype
235
- DW = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
236
- DB = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
251
+ sm_count = 1
252
+ if X.device.type == "cuda":
253
+ sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
254
+ elif X.device.type == "xpu":
255
+ sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
256
+
257
+ # Accumulate per-program partial sums in fp32 for numerical stability; they are reduced with .sum(dim=0) below.
258
+ _DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
259
+ _DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
237
260
 
238
261
  # Calculate optimal block size and warp configuration
239
262
  BLOCK_SIZE, num_warps = calculate_settings(n_cols)
240
263
  if n_cols > BLOCK_SIZE:
241
264
  raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
265
+ rows_per_program = math.ceil(n_rows / sm_count)
266
+ grid = (sm_count,)
242
267
 
243
- # Determine dtype for triton operations
244
- triton_dtype = (
245
- tl.float32
246
- if X.dtype == torch.float32
247
- else tl.bfloat16
248
- if X.dtype == torch.bfloat16
249
- else tl.float16
250
- if X.dtype == torch.float16
251
- else tl.float32 # fallback
252
- )
253
-
254
- # Use float32 for atomic operations if bfloat16 is not supported
255
- atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype
268
+ # Allocate gradient tensors
269
+ DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
256
270
 
257
271
  kernel_args = {"num_warps": num_warps}
258
272
  # XPU-specific optimization
@@ -260,28 +274,33 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
260
274
  kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
261
275
 
262
276
  # Launch sm_count programs; each one processes up to rows_per_program consecutive rows
263
- grid = (n_rows,)
264
277
  _layer_norm_backward_kernel[grid](
265
278
  X,
279
+ X.stride(0),
266
280
  W,
267
281
  Mean,
282
+ Mean.stride(0),
268
283
  RSTD,
284
+ RSTD.stride(0),
269
285
  DX,
270
- DW,
271
- DB,
272
- dY,
273
- X.stride(0),
274
286
  DX.stride(0),
287
+ _DW,
288
+ _DW.stride(0),
289
+ _DB,
290
+ _DB.stride(0),
291
+ dY,
275
292
  dY.stride(0),
293
+ n_rows,
276
294
  n_cols,
295
+ rows_per_program=rows_per_program,
277
296
  BLOCK_SIZE=BLOCK_SIZE,
278
- dtype=triton_dtype,
279
- atomic_dtype=atomic_dtype,
280
297
  **kernel_args,
281
298
  )
282
299
 
283
300
  DX = DX.view(*shape)
284
- return DX, DW.to(W.dtype), DB.to(W.dtype)
301
+ DW = _DW.sum(dim=0).to(W.dtype)
302
+ DB = _DB.sum(dim=0).to(B.dtype)
303
+ return DX, DW, DB
285
304
 
286
305
 
287
306
  class LigerLayerNormFunction(torch.autograd.Function):
@@ -0,0 +1,136 @@
1
+ import math
2
+
3
+ from typing import Callable
4
+ from typing import List
5
+ from typing import Optional
6
+
7
+ import torch
8
+
9
+ from liger_kernel.ops.utils import ensure_contiguous
10
+
11
+
12
+ class LigerTiledMLPFunction(torch.autograd.Function):
13
+ """
14
+ Based on DeepSpeed's TiledMLP:
15
+ https://github.com/deepspeedai/DeepSpeed/blob/v0.18.2/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L838
16
+
17
+ Perform a tiled MLP computation to massively reduce memory usage needed to compute MLP
18
+ when using very long sequence lengths.
19
+
20
+ This module re-computes `forward` in the `backward`. So the `forward` occurs twice each iteration.
21
+ And if you're using activation checkpointing it then occurs thrice.
22
+
23
+ Args:
24
+ fn: the function to call on sharded inputs (e.g., mlp.forward)
25
+ mlp_module: the MLP nn.Module object
26
+ x: the input to MLP.forward (hidden_states)
27
+ shards: how many shards to use
28
+ compute_params: a list of weights engaged in the compute
29
+
30
+ Returns:
31
+ the computed hidden_states
32
+ """
33
+
34
+ @staticmethod
35
+ @ensure_contiguous
36
+ def forward(
37
+ ctx,
38
+ fn: Callable,
39
+ mlp_module: torch.nn.Module,
40
+ x: torch.Tensor,
41
+ shards: int,
42
+ compute_params: Optional[List[torch.nn.Parameter]] = None,
43
+ ) -> torch.Tensor:
44
+ ctx.fn = fn
45
+ ctx.mlp_module = mlp_module
46
+ ctx.shards = shards
47
+ ctx.save_for_backward(x)
48
+
49
+ # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
50
+ x_shards = list(torch.chunk(x, chunks=shards, dim=-2))
51
+ with torch.no_grad():
52
+ output_shards = [fn(mlp_module, x_shard) for x_shard in x_shards]
53
+ output_unsharded = torch.cat(output_shards, dim=-2)
54
+
55
+ return output_unsharded
56
+
57
+ @staticmethod
58
+ @ensure_contiguous
59
+ def backward(ctx, *grads) -> tuple:
60
+ fn = ctx.fn
61
+ (x,) = ctx.saved_tensors
62
+ mlp_module = ctx.mlp_module
63
+ shards = ctx.shards
64
+
65
+ x_requires_grad = x.requires_grad
66
+ x = x.detach()
67
+ # detach() unsets x.requires_grad, so restore it
68
+ x.requires_grad_(x_requires_grad)
69
+
70
+ # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
71
+ hidden_size = x.shape[-1]
72
+ x_shape_orig = x.shape
73
+
74
+ # flatten bs+seqlen to avoid having stride issues when narrowing into seqlen w/ bs>1
75
+ x = x.view(-1, hidden_size)
76
+ incoming_grad = grads[0].view(-1, hidden_size)
77
+ x_grad = torch.zeros_like(x)
78
+
79
+ x_shards = list(torch.chunk(x, chunks=shards, dim=0))
80
+
81
+ for i, x_shard in enumerate(x_shards):
82
+ x_shard.requires_grad_(x_requires_grad)
83
+
84
+ # if seqlen is not exactly divisible by shards the last step will be shorter than shard_step
85
+ shard_step = x_shards[i].shape[0]
86
+ shard_offset = i * x_shards[0].shape[0]
87
+
88
+ x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
89
+ incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
90
+
91
+ with torch.enable_grad():
92
+ output = fn(mlp_module, x_shard)
93
+ torch.autograd.backward(output, incoming_grad_shard)
94
+
95
+ # unflatten
96
+ x_grad = x_grad.view(x_shape_orig)
97
+
98
+ return (None, None, x_grad, None, None)
99
+
100
+
101
+ def apply_tiled_mlp(
102
+ fn: Callable,
103
+ mlp_module: torch.nn.Module,
104
+ x: torch.Tensor,
105
+ num_shards: Optional[int] = None,
106
+ compute_params: Optional[List[torch.nn.Parameter]] = None,
107
+ ) -> torch.Tensor:
108
+ """
109
+ Apply tiled MLP computation for memory efficiency.
110
+
111
+ Args:
112
+ fn: the function to call on sharded inputs (e.g., lambda module, x: module(x))
113
+ mlp_module: the MLP nn.Module object
114
+ x: the input tensor with shape [bs, seqlen, hidden_size] or [seqlen, hidden_size]
115
+ num_shards: number of shards to use. If None, automatically calculated as ceil(seqlen / hidden_size)
116
+ compute_params: list of parameters for DeepSpeed ZeRO optimization
117
+
118
+ Returns:
119
+ output tensor with the same shape as input
120
+ """
121
+ if num_shards is None:
122
+ # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size]
123
+ hidden_size = x.shape[-1]
124
+ seqlen = x.shape[-2]
125
+ num_shards = math.ceil(seqlen / hidden_size)
126
+
127
+ # Ensure num_shards is at least 1
128
+ num_shards = max(1, num_shards)
129
+
130
+ return LigerTiledMLPFunction.apply(
131
+ fn,
132
+ mlp_module,
133
+ x,
134
+ num_shards,
135
+ compute_params,
136
+ )
@@ -24,6 +24,8 @@ from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP # noqa: F4
24
24
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP # noqa: F401
25
25
  from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP # noqa: F401
26
26
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP # noqa: F401
27
+ from liger_kernel.transformers.tiled_mlp import LigerTiledGEGLUMLP # noqa: F401
28
+ from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP # noqa: F401
27
29
  from liger_kernel.transformers.tvd import LigerTVDLoss # noqa: F401
28
30
 
29
31
  # Static-only imports for IDEs and type checkers
@@ -40,6 +42,8 @@ if TYPE_CHECKING:
40
42
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v # noqa: F401
41
43
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe # noqa: F401
42
44
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
45
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_dense # noqa: F401
46
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_moe # noqa: F401
43
47
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_internvl # noqa: F401
44
48
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
45
49
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4 # noqa: F401
@@ -48,6 +52,7 @@ if TYPE_CHECKING:
48
52
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401
49
53
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama # noqa: F401
50
54
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2 # noqa: F401
55
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo3 # noqa: F401
51
56
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_paligemma # noqa: F401
52
57
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3 # noqa: F401
53
58
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2 # noqa: F401
@@ -56,6 +61,8 @@ if TYPE_CHECKING:
56
61
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3 # noqa: F401
57
62
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe # noqa: F401
58
63
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_next # noqa: F401
64
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl # noqa: F401
65
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl_moe # noqa: F401
59
66
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3 # noqa: F401
60
67
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smolvlm # noqa: F401
61
68
 
@@ -112,6 +119,7 @@ def __getattr__(name: str):
112
119
  "apply_liger_kernel_to_mixtral",
113
120
  "apply_liger_kernel_to_mllama",
114
121
  "apply_liger_kernel_to_olmo2",
122
+ "apply_liger_kernel_to_olmo3",
115
123
  "apply_liger_kernel_to_paligemma",
116
124
  "apply_liger_kernel_to_phi3",
117
125
  "apply_liger_kernel_to_qwen2",
@@ -120,8 +128,12 @@ def __getattr__(name: str):
120
128
  "apply_liger_kernel_to_qwen3",
121
129
  "apply_liger_kernel_to_qwen3_moe",
122
130
  "apply_liger_kernel_to_qwen3_next",
131
+ "apply_liger_kernel_to_qwen3_vl",
132
+ "apply_liger_kernel_to_qwen3_vl_moe",
123
133
  "apply_liger_kernel_to_smollm3",
124
134
  "apply_liger_kernel_to_smolvlm",
135
+ "apply_liger_kernel_to_hunyuan_v1_dense",
136
+ "apply_liger_kernel_to_hunyuan_v1_moe",
125
137
  }
126
138
 
127
139
  if name in monkey_patch_symbols:
@@ -151,6 +163,8 @@ __all__ = [
151
163
  "LigerPhi3SwiGLUMLP",
152
164
  "LigerQwen3MoeSwiGLUMLP",
153
165
  "LigerSwiGLUMLP",
166
+ "LigerTiledGEGLUMLP",
167
+ "LigerTiledSwiGLUMLP",
154
168
  "LigerTVDLoss",
155
169
  "LigerKLDIVLoss",
156
170
  "LigerMultiTokenAttention",
@@ -182,6 +196,7 @@ if _TRANSFORMERS_AVAILABLE:
182
196
  "apply_liger_kernel_to_mixtral",
183
197
  "apply_liger_kernel_to_mllama",
184
198
  "apply_liger_kernel_to_olmo2",
199
+ "apply_liger_kernel_to_olmo3",
185
200
  "apply_liger_kernel_to_paligemma",
186
201
  "apply_liger_kernel_to_phi3",
187
202
  "apply_liger_kernel_to_qwen2",
@@ -190,7 +205,11 @@ if _TRANSFORMERS_AVAILABLE:
190
205
  "apply_liger_kernel_to_qwen3",
191
206
  "apply_liger_kernel_to_qwen3_moe",
192
207
  "apply_liger_kernel_to_qwen3_next",
208
+ "apply_liger_kernel_to_qwen3_vl",
209
+ "apply_liger_kernel_to_qwen3_vl_moe",
193
210
  "apply_liger_kernel_to_smollm3",
194
211
  "apply_liger_kernel_to_smolvlm",
212
+ "apply_liger_kernel_to_hunyuan_v1_dense",
213
+ "apply_liger_kernel_to_hunyuan_v1_moe",
195
214
  ]
196
215
  )
@@ -3,6 +3,7 @@ from typing import Optional
3
3
  import torch
4
4
 
5
5
  from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
6
+ from liger_kernel.transformers.functional import CrossEntropyOutput
6
7
 
7
8
 
8
9
  class LigerCrossEntropyLoss(torch.nn.Module):
@@ -15,6 +16,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
15
16
  reduction: str = "mean",
16
17
  softcap: Optional[float] = None,
17
18
  return_z_loss: bool = False,
19
+ return_token_accuracy: bool = False,
18
20
  ):
19
21
  super().__init__()
20
22
  assert (label_smoothing >= 0) and (label_smoothing <= 1), (
@@ -33,9 +35,10 @@ class LigerCrossEntropyLoss(torch.nn.Module):
33
35
  self.reduction = reduction
34
36
  self.softcap = softcap
35
37
  self.return_z_loss = return_z_loss
38
+ self.return_token_accuracy = return_token_accuracy
36
39
 
37
40
  def forward(self, _input: torch.Tensor, target: torch.Tensor):
38
- loss, z_loss = LigerCrossEntropyFunction.apply(
41
+ loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
39
42
  _input,
40
43
  target,
41
44
  self.weight,
@@ -45,7 +48,9 @@ class LigerCrossEntropyLoss(torch.nn.Module):
45
48
  self.reduction,
46
49
  self.softcap,
47
50
  self.return_z_loss,
51
+ self.return_token_accuracy,
48
52
  )
49
- if not self.return_z_loss:
53
+ if not self.return_z_loss and not self.return_token_accuracy:
50
54
  return loss
51
- return loss, z_loss
55
+
56
+ return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)