liger-kernel-nightly 0.6.4.dev20251202054858__py3-none-any.whl → 0.6.4.dev20260107181130__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic; see the advisory details accompanying this release.

Files changed (58)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +7 -1
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +10 -3
  3. liger_kernel/chunked_loss/jsd_loss.py +21 -6
  4. liger_kernel/ops/__init__.py +141 -0
  5. liger_kernel/ops/backends/README.md +151 -0
  6. liger_kernel/ops/backends/__init__.py +13 -0
  7. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  8. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  9. liger_kernel/ops/backends/_ascend/ops/__init__.py +43 -0
  10. liger_kernel/ops/backends/_ascend/ops/geglu.py +244 -0
  11. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  12. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  13. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  14. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  15. liger_kernel/ops/backends/registry.py +61 -0
  16. liger_kernel/ops/cross_entropy.py +12 -3
  17. liger_kernel/ops/fused_linear_cross_entropy.py +2 -1
  18. liger_kernel/ops/geglu.py +3 -2
  19. liger_kernel/ops/rms_norm.py +126 -49
  20. liger_kernel/ops/utils.py +12 -0
  21. liger_kernel/transformers/__init__.py +3 -0
  22. liger_kernel/transformers/auto_model.py +21 -0
  23. liger_kernel/transformers/cross_entropy.py +1 -1
  24. liger_kernel/transformers/dyt.py +1 -1
  25. liger_kernel/transformers/experimental/embedding.py +1 -1
  26. liger_kernel/transformers/functional.py +20 -20
  27. liger_kernel/transformers/fused_add_rms_norm.py +1 -1
  28. liger_kernel/transformers/fused_linear_cross_entropy.py +1 -1
  29. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  30. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  31. liger_kernel/transformers/geglu.py +1 -1
  32. liger_kernel/transformers/group_norm.py +1 -1
  33. liger_kernel/transformers/grpo_loss.py +1 -1
  34. liger_kernel/transformers/jsd.py +1 -1
  35. liger_kernel/transformers/kl_div.py +1 -1
  36. liger_kernel/transformers/layer_norm.py +1 -1
  37. liger_kernel/transformers/llama4_rope.py +1 -1
  38. liger_kernel/transformers/model/gemma3.py +1 -0
  39. liger_kernel/transformers/model/gpt_oss.py +211 -0
  40. liger_kernel/transformers/model/paligemma.py +1 -0
  41. liger_kernel/transformers/monkey_patch.py +118 -39
  42. liger_kernel/transformers/multi_token_attention.py +1 -1
  43. liger_kernel/transformers/poly_norm.py +1 -1
  44. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  45. liger_kernel/transformers/rms_norm.py +8 -3
  46. liger_kernel/transformers/rope.py +28 -27
  47. liger_kernel/transformers/softmax.py +1 -1
  48. liger_kernel/transformers/sparsemax.py +1 -1
  49. liger_kernel/transformers/swiglu.py +1 -1
  50. liger_kernel/transformers/tiled_mlp.py +3 -3
  51. liger_kernel/transformers/tvd.py +1 -1
  52. liger_kernel/utils.py +27 -0
  53. {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107181130.dist-info}/METADATA +9 -3
  54. {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107181130.dist-info}/RECORD +58 -46
  55. {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107181130.dist-info}/LICENSE +0 -0
  56. {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107181130.dist-info}/NOTICE +0 -0
  57. {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107181130.dist-info}/WHEEL +0 -0
  58. {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107181130.dist-info}/top_level.txt +0 -0
@@ -54,6 +54,7 @@ def _rms_norm_forward_kernel(
54
54
  eps,
55
55
  offset,
56
56
  casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
57
+ elementwise_affine: tl.constexpr,
57
58
  BLOCK_SIZE: tl.constexpr,
58
59
  ):
59
60
  """
@@ -75,7 +76,8 @@ def _rms_norm_forward_kernel(
75
76
 
76
77
  X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
77
78
  X_row_dtype = X_row.dtype
78
- W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
79
+ if elementwise_affine:
80
+ W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
79
81
 
80
82
  # On Llama, only rstd is computed on fp32
81
83
  if casting_mode == _CASTING_MODE_LLAMA:
@@ -83,7 +85,8 @@ def _rms_norm_forward_kernel(
83
85
 
84
86
  # Gemma computes everything on fp32, and then casts back the output to the original dtype
85
87
  if casting_mode == _CASTING_MODE_GEMMA:
86
- W_row = W_row.to(tl.float32)
88
+ if elementwise_affine:
89
+ W_row = W_row.to(tl.float32)
87
90
  X_row = X_row.to(tl.float32)
88
91
 
89
92
  if casting_mode == _CASTING_MODE_NONE:
@@ -104,7 +107,10 @@ def _rms_norm_forward_kernel(
104
107
  if casting_mode == _CASTING_MODE_LLAMA:
105
108
  X_row = X_row.to(X_row_dtype)
106
109
 
107
- Y_row = X_row * (offset + W_row)
110
+ if elementwise_affine:
111
+ Y_row = X_row * (offset + W_row)
112
+ else:
113
+ Y_row = X_row
108
114
 
109
115
  if casting_mode == _CASTING_MODE_GEMMA:
110
116
  Y_row = Y_row.to(X_row_dtype)
@@ -130,8 +136,9 @@ def _rms_norm_backward_kernel(
130
136
  n_rows,
131
137
  n_cols,
132
138
  offset,
133
- rows_per_program: tl.constexpr,
139
+ rows_per_program,
134
140
  casting_mode: tl.constexpr,
141
+ elementwise_affine: tl.constexpr,
135
142
  BLOCK_SIZE: tl.constexpr,
136
143
  ):
137
144
  """
@@ -145,7 +152,8 @@ def _rms_norm_backward_kernel(
145
152
  col_offsets = tl.arange(0, BLOCK_SIZE)
146
153
  mask = col_offsets < n_cols
147
154
 
148
- dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
155
+ if elementwise_affine:
156
+ dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
149
157
 
150
158
  dY_ptr += row_start * dY_row_stride
151
159
  dX_ptr += row_start * dX_row_stride
@@ -153,8 +161,9 @@ def _rms_norm_backward_kernel(
153
161
  X_ptr += row_start * X_row_stride
154
162
  RSTD_ptr += row_start
155
163
 
156
- W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
157
- W_row = W_row + offset
164
+ if elementwise_affine:
165
+ W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
166
+ W_row = W_row + offset
158
167
 
159
168
  for _ in range(row_start, row_end):
160
169
  dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
@@ -167,24 +176,34 @@ def _rms_norm_backward_kernel(
167
176
 
168
177
  # Different bacward graphs for different casting modes
169
178
  if casting_mode == _CASTING_MODE_LLAMA:
170
- m = (dY_row * W_row).to(tl.float32)
179
+ if elementwise_affine:
180
+ m = (dY_row * W_row).to(tl.float32)
181
+ else:
182
+ m = dY_row.to(tl.float32)
171
183
 
172
184
  elif casting_mode == _CASTING_MODE_GEMMA:
173
185
  dY_row = dY_row.to(tl.float32)
174
- m = dY_row * W_row
186
+ if elementwise_affine:
187
+ m = dY_row * W_row
188
+ else:
189
+ m = dY_row
175
190
  else:
176
- m = dY_row * W_row
191
+ if elementwise_affine:
192
+ m = dY_row * W_row
193
+ else:
194
+ m = dY_row
177
195
 
178
196
  dX_row = rstd_row * m
179
197
 
180
198
  dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
181
199
 
182
- # calculate the gradient of W
183
- if casting_mode == _CASTING_MODE_LLAMA:
184
- dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
185
- else:
186
- # here X_row is already in fp32 (see previous if block)
187
- dW_row += dY_row * (X_row * rstd_row)
200
+ if elementwise_affine:
201
+ # calculate the gradient of W
202
+ if casting_mode == _CASTING_MODE_LLAMA:
203
+ dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
204
+ else:
205
+ # here X_row is already in fp32 (see previous if block)
206
+ dW_row += dY_row * (X_row * rstd_row)
188
207
 
189
208
  tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
190
209
 
@@ -193,7 +212,8 @@ def _rms_norm_backward_kernel(
193
212
  X_ptr += X_row_stride
194
213
  RSTD_ptr += RSTD_row_stride
195
214
 
196
- tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
215
+ if elementwise_affine:
216
+ tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
197
217
 
198
218
 
199
219
  @triton.jit
@@ -211,6 +231,7 @@ def _block_rms_norm_forward_kernel(
211
231
  eps,
212
232
  offset,
213
233
  casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
234
+ elementwise_affine: tl.constexpr,
214
235
  BLOCK_SIZE: tl.constexpr,
215
236
  BLOCK_ROW: tl.constexpr,
216
237
  ):
@@ -234,7 +255,8 @@ def _block_rms_norm_forward_kernel(
234
255
  other=0,
235
256
  )
236
257
  X_row_dtype = X_row.dtype
237
- W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
258
+ if elementwise_affine:
259
+ W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
238
260
 
239
261
  # On Llama, only rstd is computed on fp32
240
262
  if casting_mode == _CASTING_MODE_LLAMA:
@@ -242,7 +264,8 @@ def _block_rms_norm_forward_kernel(
242
264
 
243
265
  # Gemma computes everything on fp32, and then casts back the output to the original dtype
244
266
  if casting_mode == _CASTING_MODE_GEMMA:
245
- W_row = W_row.to(tl.float32)
267
+ if elementwise_affine:
268
+ W_row = W_row.to(tl.float32)
246
269
  X_row = X_row.to(tl.float32)
247
270
 
248
271
  if casting_mode == _CASTING_MODE_NONE:
@@ -263,7 +286,10 @@ def _block_rms_norm_forward_kernel(
263
286
  if casting_mode == _CASTING_MODE_LLAMA:
264
287
  X_row = X_row.to(X_row_dtype)
265
288
 
266
- Y_row = X_row * (offset + W_row)[None, :]
289
+ if elementwise_affine:
290
+ Y_row = X_row * (offset + W_row)[None, :]
291
+ else:
292
+ Y_row = X_row
267
293
 
268
294
  if casting_mode == _CASTING_MODE_GEMMA:
269
295
  Y_row = Y_row.to(X_row_dtype)
@@ -293,8 +319,8 @@ def _block_rms_norm_backward_kernel(
293
319
  n_rows,
294
320
  n_cols,
295
321
  offset,
296
- rows_per_program: tl.constexpr,
297
322
  casting_mode: tl.constexpr,
323
+ elementwise_affine: tl.constexpr,
298
324
  BLOCK_SIZE: tl.constexpr,
299
325
  BLOCK_ROW: tl.constexpr,
300
326
  ):
@@ -309,10 +335,11 @@ def _block_rms_norm_backward_kernel(
309
335
  col_offsets = tl.arange(0, BLOCK_SIZE)
310
336
  col_mask = col_offsets < n_cols
311
337
 
312
- dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
338
+ if elementwise_affine:
339
+ dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
313
340
 
314
- W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
315
- W_row = W_row + offset
341
+ W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
342
+ W_row = W_row + offset
316
343
 
317
344
  for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
318
345
  row_idx = start + tl.arange(0, BLOCK_ROW)
@@ -335,13 +362,22 @@ def _block_rms_norm_backward_kernel(
335
362
 
336
363
  # Different bacward graphs for different casting modes
337
364
  if casting_mode == _CASTING_MODE_LLAMA:
338
- m = (dY_row * W_row[None, :]).to(tl.float32)
365
+ if elementwise_affine:
366
+ m = (dY_row * W_row[None, :]).to(tl.float32)
367
+ else:
368
+ m = dY_row.to(tl.float32)
339
369
 
340
370
  elif casting_mode == _CASTING_MODE_GEMMA:
341
371
  dY_row = dY_row.to(tl.float32)
342
- m = dY_row * W_row[None, :]
372
+ if elementwise_affine:
373
+ m = dY_row * W_row[None, :]
374
+ else:
375
+ m = dY_row
343
376
  else:
344
- m = dY_row * W_row[None, :]
377
+ if elementwise_affine:
378
+ m = dY_row * W_row[None, :]
379
+ else:
380
+ m = dY_row
345
381
 
346
382
  dX_row = rstd_row[:, None] * m
347
383
 
@@ -349,13 +385,13 @@ def _block_rms_norm_backward_kernel(
349
385
  -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
350
386
  )
351
387
 
352
- # calculate the gradient of W
353
- if casting_mode == _CASTING_MODE_LLAMA:
354
- # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
355
- dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
356
- else:
357
- # here X_row is already in fp32 (see previous if block)
358
- dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
388
+ if elementwise_affine:
389
+ if casting_mode == _CASTING_MODE_LLAMA:
390
+ # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
391
+ dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
392
+ else:
393
+ # here X_row is already in fp32 (see previous if block)
394
+ dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
359
395
 
360
396
  tl.store(
361
397
  dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
@@ -363,7 +399,8 @@ def _block_rms_norm_backward_kernel(
363
399
  mask=row_mask[:, None] & col_mask[None, :],
364
400
  )
365
401
 
366
- tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
402
+ if elementwise_affine:
403
+ tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
367
404
 
368
405
 
369
406
  _str_to_casting_mode = {
@@ -392,8 +429,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
392
429
  rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
393
430
  RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
394
431
 
395
- # Check constraints.
396
- assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
432
+ if W is not None:
433
+ # Check constraints.
434
+ assert X.shape[1] == W.shape[0], (
435
+ "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
436
+ )
437
+ elementwise_affine = True
438
+ else:
439
+ elementwise_affine = False
397
440
 
398
441
  # XPU-specific optimization
399
442
  kernel_args = {}
@@ -406,13 +449,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
406
449
  X,
407
450
  X.stride(0),
408
451
  W,
409
- W.stride(0),
452
+ W.stride(0) if elementwise_affine else 0,
410
453
  RSTD,
411
454
  RSTD.stride(0),
412
455
  n_cols,
413
456
  eps,
414
457
  offset,
415
458
  casting_mode,
459
+ elementwise_affine=elementwise_affine,
416
460
  BLOCK_SIZE=BLOCK_SIZE,
417
461
  num_warps=num_warps,
418
462
  **kernel_args, # XPU-specific optimization
@@ -426,7 +470,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
426
470
  X,
427
471
  X.stride(0),
428
472
  W,
429
- W.stride(0),
473
+ W.stride(0) if elementwise_affine else 0,
430
474
  RSTD,
431
475
  RSTD.stride(0),
432
476
  n_rows,
@@ -434,6 +478,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
434
478
  eps,
435
479
  offset,
436
480
  casting_mode,
481
+ elementwise_affine=elementwise_affine,
437
482
  BLOCK_SIZE=BLOCK_SIZE,
438
483
  num_warps=num_warps,
439
484
  **kernel_args, # XPU-specific optimization
@@ -455,8 +500,13 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
455
500
  elif X.device.type == "npu":
456
501
  sm_count = get_npu_multi_processor_count()
457
502
 
458
- # fp32 for numerical stability especially.
459
- _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
503
+ if W is not None:
504
+ # fp32 for numerical stability especially.
505
+ _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
506
+ elementwise_affine = True
507
+ else:
508
+ _dW = None
509
+ elementwise_affine = False
460
510
 
461
511
  if n_cols > BLOCK_SIZE:
462
512
  raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
@@ -483,16 +533,17 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
483
533
  X.stride(0),
484
534
  torch_to_triton_dtype[X.dtype],
485
535
  W,
486
- W.stride(0),
536
+ W.stride(0) if elementwise_affine else 0,
487
537
  RSTD,
488
538
  RSTD.stride(0),
489
539
  _dW,
490
- _dW.stride(0),
540
+ _dW.stride(0) if elementwise_affine else 0,
491
541
  n_rows,
492
542
  n_cols,
493
543
  offset,
494
544
  rows_per_program,
495
545
  casting_mode,
546
+ elementwise_affine=elementwise_affine,
496
547
  BLOCK_SIZE=BLOCK_SIZE,
497
548
  num_warps=num_warps,
498
549
  **kernel_args, # XPU-specific optimization
@@ -509,22 +560,26 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
509
560
  X.stride(0),
510
561
  torch_to_triton_dtype[X.dtype],
511
562
  W,
512
- W.stride(0),
563
+ W.stride(0) if elementwise_affine else 0,
513
564
  RSTD,
514
565
  RSTD.stride(0),
515
566
  _dW,
516
- _dW.stride(0),
567
+ _dW.stride(0) if elementwise_affine else 0,
517
568
  n_rows,
518
569
  n_cols,
519
570
  offset,
520
- rows_per_program,
521
571
  casting_mode,
572
+ elementwise_affine=elementwise_affine,
522
573
  BLOCK_SIZE=BLOCK_SIZE,
523
574
  num_warps=num_warps,
524
575
  **kernel_args, # XPU-specific optimization
525
576
  )
526
577
  dX = dX.view(*shape)
527
- dW = _dW.sum(dim=0).to(W.dtype)
578
+
579
+ if elementwise_affine:
580
+ dW = _dW.sum(dim=0).to(W.dtype)
581
+ else:
582
+ dW = None
528
583
 
529
584
  return dX, dW
530
585
 
@@ -558,6 +613,13 @@ class LigerRMSNormFunction(torch.autograd.Function):
558
613
  X: (B, T, H) or (BxT, H)
559
614
  W: (H,)
560
615
  """
616
+ if isinstance(X, torch.distributed.tensor.DTensor):
617
+ # Input tensor is output of a tensor parallel module and
618
+ # needs to be gathered to a local tensor to compute
619
+ # RMSE layer norm on each TP worker.
620
+ # TODO: support CP.
621
+ X = X.full_tensor()
622
+
561
623
  Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
562
624
  ctx.offset = offset
563
625
  ctx.casting_mode = casting_mode
@@ -565,7 +627,11 @@ class LigerRMSNormFunction(torch.autograd.Function):
565
627
  ctx.row_mode = row_mode
566
628
  ctx.BLOCK_SIZE = BLOCK_SIZE
567
629
  ctx.num_warps = num_warps
568
- ctx.save_for_backward(X, W, RSTD)
630
+ ctx.elementwise_affine = W is not None
631
+ if W is not None:
632
+ ctx.save_for_backward(X, W, RSTD)
633
+ else:
634
+ ctx.save_for_backward(X, RSTD)
569
635
  return Y
570
636
 
571
637
  @staticmethod
@@ -574,7 +640,18 @@ class LigerRMSNormFunction(torch.autograd.Function):
574
640
  """
575
641
  Y: (B, T, H) or (BxT, H)
576
642
  """
577
- X, W, RSTD = ctx.saved_tensors
643
+ if ctx.elementwise_affine:
644
+ X, W, RSTD = ctx.saved_tensors
645
+ else:
646
+ X, RSTD = ctx.saved_tensors
647
+ W = None
648
+
649
+ if isinstance(dY, torch.distributed.tensor.DTensor):
650
+ # Gradients are output of a tensor parallel module and
651
+ # needs to be gathered to a local tensor for computing RMSE layer.
652
+ # TODO: support CP.
653
+ dY = dY.full_tensor()
654
+
578
655
  dX, dW = rms_norm_backward(
579
656
  dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
580
657
  )
liger_kernel/ops/utils.py CHANGED
@@ -127,3 +127,15 @@ def element_mul_kernel(
127
127
  X_offsets = i + tl.arange(0, BLOCK_SIZE)
128
128
  X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
129
129
  tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
130
+
131
+
132
+ def get_npu_core_count(default: int = 20) -> int:
133
+ """Return NPU vector core count.
134
+ Fallback to `default` if Triton runtime or NPU device is unavailable.
135
+ """
136
+ try:
137
+ utils = triton.runtime.driver.active.utils
138
+ props = utils.get_device_properties(0)
139
+ return int(props.get("num_vectorcore", default))
140
+ except Exception:
141
+ return default
@@ -41,6 +41,7 @@ if TYPE_CHECKING:
41
41
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401
42
42
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v # noqa: F401
43
43
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe # noqa: F401
44
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gpt_oss # noqa: F401
44
45
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
45
46
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_dense # noqa: F401
46
47
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_moe # noqa: F401
@@ -110,6 +111,7 @@ def __getattr__(name: str):
110
111
  "apply_liger_kernel_to_glm4",
111
112
  "apply_liger_kernel_to_glm4v",
112
113
  "apply_liger_kernel_to_glm4v_moe",
114
+ "apply_liger_kernel_to_gpt_oss",
113
115
  "apply_liger_kernel_to_granite",
114
116
  "apply_liger_kernel_to_internvl",
115
117
  "apply_liger_kernel_to_llama",
@@ -187,6 +189,7 @@ if _TRANSFORMERS_AVAILABLE:
187
189
  "apply_liger_kernel_to_glm4",
188
190
  "apply_liger_kernel_to_glm4v",
189
191
  "apply_liger_kernel_to_glm4v_moe",
192
+ "apply_liger_kernel_to_gpt_oss",
190
193
  "apply_liger_kernel_to_granite",
191
194
  "apply_liger_kernel_to_internvl",
192
195
  "apply_liger_kernel_to_llama",
@@ -1,4 +1,5 @@
1
1
  import inspect
2
+ import logging
2
3
 
3
4
  from transformers import AutoConfig
4
5
  from transformers import AutoModelForCausalLM
@@ -6,6 +7,8 @@ from transformers import AutoModelForCausalLM
6
7
  from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
7
8
  from liger_kernel.transformers.monkey_patch import _apply_liger_kernel
8
9
 
10
+ logger = logging.getLogger(__name__)
11
+
9
12
 
10
13
  def _get_model_config(model_dir, **model_init_kwargs):
11
14
  config = AutoConfig.from_pretrained(model_dir, **model_init_kwargs)
@@ -36,3 +39,21 @@ class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
36
39
  applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}
37
40
 
38
41
  return super().from_pretrained(pretrained_model_name_or_path, *model_args, **applicable_kwargs)
42
+
43
+ @classmethod
44
+ def from_config(cls, config, **kwargs):
45
+ model_type = getattr(config, "model_type", None)
46
+ if not model_type:
47
+ logger.info("Model type could not be determined from model config. No Liger kernels will be applied.")
48
+ return
49
+ model_type = config.model_type
50
+
51
+ _apply_liger_kernel(model_type, **kwargs)
52
+
53
+ # Filter out kwargs that were passed to the apply_liger_* function, which will cause
54
+ # model initialization errors otherwise
55
+ apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
56
+ apply_fn_signature = inspect.signature(apply_fn)
57
+ applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}
58
+
59
+ return super().from_config(config, **applicable_kwargs)
@@ -2,7 +2,7 @@ from typing import Optional
2
2
 
3
3
  import torch
4
4
 
5
- from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
5
+ from liger_kernel.ops import LigerCrossEntropyFunction
6
6
  from liger_kernel.transformers.functional import CrossEntropyOutput
7
7
 
8
8
 
@@ -1,7 +1,7 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
3
 
4
- from liger_kernel.ops.dyt import LigerDyTFunction
4
+ from liger_kernel.ops import LigerDyTFunction
5
5
 
6
6
 
7
7
  class LigerDyT(nn.Module):
@@ -3,7 +3,7 @@ from typing import Optional
3
3
  import torch
4
4
  import torch.nn as nn
5
5
 
6
- from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction
6
+ from liger_kernel.ops import LigerEmbeddingFunction
7
7
 
8
8
 
9
9
  class LigerEmbedding(nn.Module):
@@ -3,26 +3,26 @@ from typing import Optional
3
3
 
4
4
  import torch
5
5
 
6
- from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
7
- from liger_kernel.ops.dyt import LigerDyTFunction
8
- from liger_kernel.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction
9
- from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
10
- from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
11
- from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction
12
- from liger_kernel.ops.geglu import LigerGELUMulFunction
13
- from liger_kernel.ops.group_norm import LigerGroupNormFunction
14
- from liger_kernel.ops.jsd import LigerJSDFunction
15
- from liger_kernel.ops.kl_div import LigerKLDivLossFunction
16
- from liger_kernel.ops.layer_norm import LigerLayerNormFunction
17
- from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction
18
- from liger_kernel.ops.poly_norm import LigerPolyNormFunction
19
- from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
20
- from liger_kernel.ops.rms_norm import LigerRMSNormFunction
21
- from liger_kernel.ops.rope import LigerRopeFunction
22
- from liger_kernel.ops.softmax import LigerSoftmaxFunction
23
- from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
24
- from liger_kernel.ops.swiglu import LigerSiLUMulFunction
25
- from liger_kernel.ops.tvd import LigerTVDLossFunction
6
+ from liger_kernel.ops import LigerCrossEntropyFunction
7
+ from liger_kernel.ops import LigerDyTFunction
8
+ from liger_kernel.ops import LigerFusedAddRMSNormFunction
9
+ from liger_kernel.ops import LigerFusedLinearCrossEntropyFunction
10
+ from liger_kernel.ops import LigerFusedLinearJSDFunction
11
+ from liger_kernel.ops import LigerFusedNeighborhoodAttentionFunction
12
+ from liger_kernel.ops import LigerGELUMulFunction
13
+ from liger_kernel.ops import LigerGroupNormFunction
14
+ from liger_kernel.ops import LigerJSDFunction
15
+ from liger_kernel.ops import LigerKLDivLossFunction
16
+ from liger_kernel.ops import LigerLayerNormFunction
17
+ from liger_kernel.ops import LigerMultiTokenAttentionFunction
18
+ from liger_kernel.ops import LigerPolyNormFunction
19
+ from liger_kernel.ops import LigerQwen2VLMRopeFunction
20
+ from liger_kernel.ops import LigerRMSNormFunction
21
+ from liger_kernel.ops import LigerRopeFunction
22
+ from liger_kernel.ops import LigerSiLUMulFunction
23
+ from liger_kernel.ops import LigerSoftmaxFunction
24
+ from liger_kernel.ops import LigerSparsemaxFunction
25
+ from liger_kernel.ops import LigerTVDLossFunction
26
26
 
27
27
 
28
28
  @dataclass
@@ -1,7 +1,7 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
3
 
4
- from liger_kernel.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction
4
+ from liger_kernel.ops import LigerFusedAddRMSNormFunction
5
5
 
6
6
 
7
7
  class LigerFusedAddRMSNorm(nn.Module):
@@ -2,7 +2,7 @@ from typing import Optional
2
2
 
3
3
  import torch
4
4
 
5
- from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
5
+ from liger_kernel.ops import LigerFusedLinearCrossEntropyFunction
6
6
  from liger_kernel.transformers.functional import CrossEntropyOutput
7
7
 
8
8
 
@@ -2,7 +2,7 @@ from typing import Optional
2
2
 
3
3
  import torch
4
4
 
5
- from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
5
+ from liger_kernel.ops import LigerFusedLinearJSDFunction
6
6
 
7
7
 
8
8
  class LigerFusedLinearJSD(torch.nn.Module):
@@ -5,7 +5,7 @@ from typing import Optional
5
5
  import torch
6
6
  import torch.nn as nn
7
7
 
8
- from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction
8
+ from liger_kernel.ops import LigerFusedNeighborhoodAttentionFunction
9
9
 
10
10
 
11
11
  class LigerFusedNeighborhoodAttention(nn.Module):
@@ -1,6 +1,6 @@
1
1
  import torch.nn as nn
2
2
 
3
- from liger_kernel.ops.geglu import LigerGELUMulFunction
3
+ from liger_kernel.ops import LigerGELUMulFunction
4
4
 
5
5
 
6
6
  class LigerGEGLUMLP(nn.Module):
@@ -1,7 +1,7 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
3
 
4
- from liger_kernel.ops.group_norm import LigerGroupNormFunction
4
+ from liger_kernel.ops import LigerGroupNormFunction
5
5
 
6
6
 
7
7
  class LigerGroupNorm(nn.Module):
@@ -1,7 +1,7 @@
1
1
  import torch
2
2
 
3
3
  from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
4
- from liger_kernel.ops.grpo_loss import GrpoLossFunction
4
+ from liger_kernel.ops import GrpoLossFunction
5
5
 
6
6
 
7
7
  def triton_grpo_loss(
@@ -2,7 +2,7 @@ from typing import Optional
2
2
 
3
3
  import torch
4
4
 
5
- from liger_kernel.ops.jsd import LigerJSDFunction
5
+ from liger_kernel.ops import LigerJSDFunction
6
6
 
7
7
 
8
8
  class LigerJSD(torch.nn.Module):
@@ -1,6 +1,6 @@
1
1
  import torch.nn as nn
2
2
 
3
- from liger_kernel.ops.kl_div import LigerKLDivLossFunction
3
+ from liger_kernel.ops import LigerKLDivLossFunction
4
4
 
5
5
 
6
6
  class LigerKLDIVLoss(nn.KLDivLoss):
@@ -1,7 +1,7 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
3
 
4
- from liger_kernel.ops.layer_norm import LigerLayerNormFunction
4
+ from liger_kernel.ops import LigerLayerNormFunction
5
5
 
6
6
 
7
7
  class LigerLayerNorm(nn.Module):
@@ -5,7 +5,7 @@ Supports both text and vision RoPE variants with fused operations for optimal pe
5
5
 
6
6
  import torch
7
7
 
8
- from liger_kernel.ops.llama4_rope import LigerLlama4RopeFunction
8
+ from liger_kernel.ops import LigerLlama4RopeFunction
9
9
 
10
10
 
11
11
  def liger_llama4_text_rotary_pos_emb(