liger-kernel 0.6.4 (py3-none-any.whl) → 0.6.5 (py3-none-any.whl)

Files changed (71)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +7 -1
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +10 -3
  3. liger_kernel/chunked_loss/jsd_loss.py +21 -6
  4. liger_kernel/ops/__init__.py +141 -0
  5. liger_kernel/ops/backends/README.md +151 -0
  6. liger_kernel/ops/backends/__init__.py +13 -0
  7. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  8. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +492 -0
  9. liger_kernel/ops/backends/_ascend/ops/__init__.py +61 -0
  10. liger_kernel/ops/backends/_ascend/ops/embedding.py +214 -0
  11. liger_kernel/ops/backends/_ascend/ops/geglu.py +191 -0
  12. liger_kernel/ops/backends/_ascend/ops/llama4_rope.py +298 -0
  13. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +275 -0
  14. liger_kernel/ops/backends/_ascend/ops/rope.py +265 -0
  15. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  16. liger_kernel/ops/backends/_ascend/ops/tvd.py +223 -0
  17. liger_kernel/ops/backends/_ascend/ub_manager.py +367 -0
  18. liger_kernel/ops/backends/registry.py +61 -0
  19. liger_kernel/ops/cross_entropy.py +14 -4
  20. liger_kernel/ops/dyt.py +5 -2
  21. liger_kernel/ops/fused_add_rms_norm.py +21 -23
  22. liger_kernel/ops/fused_linear_cross_entropy.py +2 -1
  23. liger_kernel/ops/geglu.py +5 -3
  24. liger_kernel/ops/group_norm.py +12 -8
  25. liger_kernel/ops/kl_div.py +8 -11
  26. liger_kernel/ops/layer_norm.py +17 -16
  27. liger_kernel/ops/poly_norm.py +19 -21
  28. liger_kernel/ops/rms_norm.py +149 -71
  29. liger_kernel/ops/utils.py +25 -0
  30. liger_kernel/transformers/__init__.py +6 -0
  31. liger_kernel/transformers/auto_model.py +21 -0
  32. liger_kernel/transformers/cross_entropy.py +1 -1
  33. liger_kernel/transformers/dyt.py +1 -1
  34. liger_kernel/transformers/experimental/embedding.py +1 -1
  35. liger_kernel/transformers/functional.py +20 -20
  36. liger_kernel/transformers/fused_add_rms_norm.py +1 -1
  37. liger_kernel/transformers/fused_linear_cross_entropy.py +1 -1
  38. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  39. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  40. liger_kernel/transformers/geglu.py +1 -1
  41. liger_kernel/transformers/group_norm.py +1 -1
  42. liger_kernel/transformers/grpo_loss.py +1 -1
  43. liger_kernel/transformers/jsd.py +1 -1
  44. liger_kernel/transformers/kl_div.py +1 -1
  45. liger_kernel/transformers/layer_norm.py +1 -1
  46. liger_kernel/transformers/llama4_rope.py +1 -1
  47. liger_kernel/transformers/model/exaone4.py +136 -0
  48. liger_kernel/transformers/model/gemma2.py +3 -3
  49. liger_kernel/transformers/model/gemma3.py +11 -5
  50. liger_kernel/transformers/model/gpt_oss.py +211 -0
  51. liger_kernel/transformers/model/loss_utils.py +6 -0
  52. liger_kernel/transformers/model/paligemma.py +1 -0
  53. liger_kernel/transformers/monkey_patch.py +196 -39
  54. liger_kernel/transformers/multi_token_attention.py +1 -1
  55. liger_kernel/transformers/poly_norm.py +1 -1
  56. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  57. liger_kernel/transformers/rms_norm.py +8 -3
  58. liger_kernel/transformers/rope.py +28 -27
  59. liger_kernel/transformers/softmax.py +1 -1
  60. liger_kernel/transformers/sparsemax.py +1 -1
  61. liger_kernel/transformers/swiglu.py +1 -1
  62. liger_kernel/transformers/tiled_mlp.py +5 -13
  63. liger_kernel/transformers/tvd.py +1 -1
  64. liger_kernel/utils.py +54 -0
  65. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/METADATA +11 -4
  66. liger_kernel-0.6.5.dist-info/RECORD +134 -0
  67. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/WHEEL +1 -1
  68. liger_kernel-0.6.4.dist-info/RECORD +0 -118
  69. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/licenses/LICENSE +0 -0
  70. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/licenses/NOTICE +0 -0
  71. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rms_norm.py CHANGED
@@ -20,9 +20,12 @@ import triton.language as tl
  from liger_kernel.ops.utils import calculate_settings
  from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.ops.utils import get_npu_core_count
+ from liger_kernel.ops.utils import set_large_grf_mode
  from liger_kernel.ops.utils import torch_to_triton_dtype
+ from liger_kernel.utils import is_npu_available

- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
      try:
          # typical import path with dispatch available
          from triton.language.extra.libdevice import rsqrt
@@ -52,6 +55,7 @@ def _rms_norm_forward_kernel(
      eps,
      offset,
      casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+     elementwise_affine: tl.constexpr,
      BLOCK_SIZE: tl.constexpr,
  ):
      """
@@ -67,13 +71,14 @@ def _rms_norm_forward_kernel(
      col_offsets = tl.arange(0, BLOCK_SIZE)
      mask = col_offsets < n_cols

-     Y_ptr += row_idx * Y_row_stride
-     X_ptr += row_idx * X_row_stride
-     RSTD_ptr += row_idx * RSTD_row_stride
+     y_base = Y_ptr + row_idx * Y_row_stride
+     x_base = X_ptr + row_idx * X_row_stride
+     rstd_base = RSTD_ptr + row_idx * RSTD_row_stride

-     X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
+     X_row = tl.load(x_base + col_offsets, mask=mask, other=0)
      X_row_dtype = X_row.dtype
-     W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
+     if elementwise_affine:
+         W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)

      # On Llama, only rstd is computed on fp32
      if casting_mode == _CASTING_MODE_LLAMA:
@@ -81,7 +86,8 @@ def _rms_norm_forward_kernel(

      # Gemma computes everything on fp32, and then casts back the output to the original dtype
      if casting_mode == _CASTING_MODE_GEMMA:
-         W_row = W_row.to(tl.float32)
+         if elementwise_affine:
+             W_row = W_row.to(tl.float32)
          X_row = X_row.to(tl.float32)

      if casting_mode == _CASTING_MODE_NONE:
@@ -94,7 +100,7 @@ def _rms_norm_forward_kernel(
      # We can save time by caching rms with minimal memory overhead
      # because rms is much smaller compared to X_row, as rms is for each row.
      # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
-     tl.store(RSTD_ptr, rstd)
+     tl.store(rstd_base, rstd)

      X_row = X_row * rstd

@@ -102,12 +108,15 @@ def _rms_norm_forward_kernel(
      if casting_mode == _CASTING_MODE_LLAMA:
          X_row = X_row.to(X_row_dtype)

-     Y_row = X_row * (offset + W_row)
+     if elementwise_affine:
+         Y_row = X_row * (offset + W_row)
+     else:
+         Y_row = X_row

      if casting_mode == _CASTING_MODE_GEMMA:
          Y_row = Y_row.to(X_row_dtype)

-     tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
+     tl.store(y_base + col_offsets, Y_row, mask=mask)


  @triton.jit
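Read together, the forward-kernel hunks above make the learnable weight optional: `elementwise_affine` is a `tl.constexpr`, so when it is false every `W_row` load, cast, and multiply is compiled out of the kernel rather than branched over at runtime. For orientation only, here is a minimal PyTorch sketch (not part of the diff; the Llama/Gemma casting modes are simplified away) of what the kernel computes in both cases:

    from typing import Optional

    import torch

    def rms_norm_ref(X: torch.Tensor, W: Optional[torch.Tensor], eps: float, offset: float = 0.0) -> torch.Tensor:
        # rstd is computed per row in fp32, mirroring the kernel's cached RSTD
        rstd = torch.rsqrt(X.float().pow(2).mean(dim=-1, keepdim=True) + eps)
        Y = (X.float() * rstd).to(X.dtype)
        if W is not None:  # elementwise_affine=True
            Y = Y * (offset + W)
        return Y  # with W=None the output is simply the normalized input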
@@ -128,8 +137,9 @@ def _rms_norm_backward_kernel(
      n_rows,
      n_cols,
      offset,
-     rows_per_program: tl.constexpr,
+     rows_per_program,
      casting_mode: tl.constexpr,
+     elementwise_affine: tl.constexpr,
      BLOCK_SIZE: tl.constexpr,
  ):
      """
@@ -143,55 +153,63 @@ def _rms_norm_backward_kernel(
      col_offsets = tl.arange(0, BLOCK_SIZE)
      mask = col_offsets < n_cols

-     dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+     if elementwise_affine:
+         dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

-     dY_ptr += row_start * dY_row_stride
-     dX_ptr += row_start * dX_row_stride
+     if elementwise_affine:
+         W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+         W_row = W_row + offset

-     X_ptr += row_start * X_row_stride
-     RSTD_ptr += row_start
+     for row_idx in range(row_start, row_end):
+         dy_base = dY_ptr + row_idx * dY_row_stride
+         dx_base = dX_ptr + row_idx * dX_row_stride

-     W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
-     W_row = W_row + offset
+         x_base = X_ptr + row_idx * X_row_stride
+         rstd_base = RSTD_ptr + row_idx * RSTD_row_stride

-     for _ in range(row_start, row_end):
-         dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
-         X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
+         dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0)
+         X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0)

          # Get cached rms
-         rstd_row = tl.load(RSTD_ptr)
+         rstd_row = tl.load(rstd_base)

          X_row = X_row.to(tl.float32)

          # Different backward graphs for different casting modes
          if casting_mode == _CASTING_MODE_LLAMA:
-             m = (dY_row * W_row).to(tl.float32)
+             if elementwise_affine:
+                 m = (dY_row * W_row).to(tl.float32)
+             else:
+                 m = dY_row.to(tl.float32)

          elif casting_mode == _CASTING_MODE_GEMMA:
              dY_row = dY_row.to(tl.float32)
-             m = dY_row * W_row
+             if elementwise_affine:
+                 m = dY_row * W_row
+             else:
+                 m = dY_row
          else:
-             m = dY_row * W_row
+             if elementwise_affine:
+                 m = dY_row * W_row
+             else:
+                 m = dY_row

          dX_row = rstd_row * m

          dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)

-         # calculate the gradient of W
-         if casting_mode == _CASTING_MODE_LLAMA:
-             dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
-         else:
-             # here X_row is already in fp32 (see previous if block)
-             dW_row += dY_row * (X_row * rstd_row)
+         if elementwise_affine:
+             # calculate the gradient of W
+             if casting_mode == _CASTING_MODE_LLAMA:
+                 dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
+             else:
+                 # here X_row is already in fp32 (see previous if block)
+                 dW_row += dY_row * (X_row * rstd_row)

-         tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
+         tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask)

-         dY_ptr += dY_row_stride
-         dX_ptr += dX_row_stride
-         X_ptr += X_row_stride
-         RSTD_ptr += RSTD_row_stride
-
-     tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
+     if elementwise_affine:
+         tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)


  @triton.jit
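The rewritten loop indexes every pointer by `row_idx` instead of mutating `dY_ptr`/`dX_ptr`/`X_ptr`/`RSTD_ptr` in place, so each iteration is self-contained; the per-row gradient itself is unchanged. In plain PyTorch terms, a sketch under the same simplifications as above, with `rstd` of shape `(n_rows, 1)`:

    import torch

    def rms_norm_backward_ref(dY, X, W, rstd, offset=0.0):
        n = X.shape[-1]
        Xf = X.float()
        # m = dY * (offset + W) when a weight exists, else just dY
        m = (dY * (offset + W) if W is not None else dY).float()
        # the kernel's two terms: rstd * m minus the projection correction
        dX = rstd * m - (1.0 / n) * rstd.pow(3) * (m * Xf).sum(-1, keepdim=True) * Xf
        # dW is accumulated only on the elementwise_affine path
        dW = (dY.float() * (Xf * rstd)).sum(0) if W is not None else None
        return dX.to(X.dtype), dW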
@@ -209,6 +227,7 @@ def _block_rms_norm_forward_kernel(
      eps,
      offset,
      casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+     elementwise_affine: tl.constexpr,
      BLOCK_SIZE: tl.constexpr,
      BLOCK_ROW: tl.constexpr,
  ):
@@ -232,7 +251,8 @@ def _block_rms_norm_forward_kernel(
          other=0,
      )
      X_row_dtype = X_row.dtype
-     W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
+     if elementwise_affine:
+         W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)

      # On Llama, only rstd is computed on fp32
      if casting_mode == _CASTING_MODE_LLAMA:
@@ -240,7 +260,8 @@ def _block_rms_norm_forward_kernel(

      # Gemma computes everything on fp32, and then casts back the output to the original dtype
      if casting_mode == _CASTING_MODE_GEMMA:
-         W_row = W_row.to(tl.float32)
+         if elementwise_affine:
+             W_row = W_row.to(tl.float32)
          X_row = X_row.to(tl.float32)

      if casting_mode == _CASTING_MODE_NONE:
@@ -261,7 +282,10 @@ def _block_rms_norm_forward_kernel(
      if casting_mode == _CASTING_MODE_LLAMA:
          X_row = X_row.to(X_row_dtype)

-     Y_row = X_row * (offset + W_row)[None, :]
+     if elementwise_affine:
+         Y_row = X_row * (offset + W_row)[None, :]
+     else:
+         Y_row = X_row

      if casting_mode == _CASTING_MODE_GEMMA:
          Y_row = Y_row.to(X_row_dtype)
@@ -291,8 +315,8 @@ def _block_rms_norm_backward_kernel(
      n_rows,
      n_cols,
      offset,
-     rows_per_program: tl.constexpr,
      casting_mode: tl.constexpr,
+     elementwise_affine: tl.constexpr,
      BLOCK_SIZE: tl.constexpr,
      BLOCK_ROW: tl.constexpr,
  ):
@@ -307,10 +331,11 @@ def _block_rms_norm_backward_kernel(
      col_offsets = tl.arange(0, BLOCK_SIZE)
      col_mask = col_offsets < n_cols

-     dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+     if elementwise_affine:
+         dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

-     W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
-     W_row = W_row + offset
+         W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
+         W_row = W_row + offset

      for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
          row_idx = start + tl.arange(0, BLOCK_ROW)
@@ -333,13 +358,22 @@ def _block_rms_norm_backward_kernel(

          # Different backward graphs for different casting modes
          if casting_mode == _CASTING_MODE_LLAMA:
-             m = (dY_row * W_row[None, :]).to(tl.float32)
+             if elementwise_affine:
+                 m = (dY_row * W_row[None, :]).to(tl.float32)
+             else:
+                 m = dY_row.to(tl.float32)

          elif casting_mode == _CASTING_MODE_GEMMA:
              dY_row = dY_row.to(tl.float32)
-             m = dY_row * W_row[None, :]
+             if elementwise_affine:
+                 m = dY_row * W_row[None, :]
+             else:
+                 m = dY_row
          else:
-             m = dY_row * W_row[None, :]
+             if elementwise_affine:
+                 m = dY_row * W_row[None, :]
+             else:
+                 m = dY_row

          dX_row = rstd_row[:, None] * m

@@ -347,12 +381,13 @@ def _block_rms_norm_backward_kernel(
              -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
          )

-         # calculate the gradient of W
-         if casting_mode == _CASTING_MODE_LLAMA:
-             dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]).to(X_dtype), 0)
-         else:
-             # here X_row is already in fp32 (see previous if block)
-             dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
+         if elementwise_affine:
+             if casting_mode == _CASTING_MODE_LLAMA:
+                 # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
+                 dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
+             else:
+                 # here X_row is already in fp32 (see previous if block)
+                 dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)

          tl.store(
              dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
@@ -360,7 +395,8 @@
              mask=row_mask[:, None] & col_mask[None, :],
          )

-     tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
+     if elementwise_affine:
+         tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)


  _str_to_casting_mode = {
@@ -389,13 +425,19 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
      rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
      RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)

-     # Check constraints.
-     assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+     if W is not None:
+         # Check constraints.
+         assert X.shape[1] == W.shape[0], (
+             "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+         )
+         elementwise_affine = True
+     else:
+         elementwise_affine = False

      # XPU-specific optimization
      kernel_args = {}
      if X.device.type == "xpu":
-         kernel_args["grf_mode"] = "large"
+         set_large_grf_mode(kernel_args)
      if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
          _rms_norm_forward_kernel[(n_rows,)](
              Y,
@@ -403,13 +445,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
              X,
              X.stride(0),
              W,
-             W.stride(0),
+             W.stride(0) if elementwise_affine else 0,
              RSTD,
              RSTD.stride(0),
              n_cols,
              eps,
              offset,
              casting_mode,
+             elementwise_affine=elementwise_affine,
              BLOCK_SIZE=BLOCK_SIZE,
              num_warps=num_warps,
              **kernel_args,  # XPU-specific optimization
@@ -423,7 +466,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
              X,
              X.stride(0),
              W,
-             W.stride(0),
+             W.stride(0) if elementwise_affine else 0,
              RSTD,
              RSTD.stride(0),
              n_rows,
@@ -431,6 +474,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
              eps,
              offset,
              casting_mode,
+             elementwise_affine=elementwise_affine,
              BLOCK_SIZE=BLOCK_SIZE,
              num_warps=num_warps,
              **kernel_args,  # XPU-specific optimization
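Both launch paths now pass `W.stride(0) if elementwise_affine else 0`; the placeholder stride is safe because the constexpr removes every `W_ptr` access at compile time, so the value is never dereferenced. A hypothetical weightless call (argument values invented for illustration):

    X = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
    # W=None selects the elementwise_affine=False specialization end to end
    Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(
        X, W=None, eps=1e-6, offset=0.0, casting_mode="llama", row_mode=None
    )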
@@ -449,9 +493,16 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
          sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
      elif X.device.type == "xpu":
          sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+     elif X.device.type == "npu":
+         sm_count = get_npu_core_count()

-     # fp32 for numerical stability especially.
-     _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+     if W is not None:
+         # fp32 for numerical stability especially.
+         _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+         elementwise_affine = True
+     else:
+         _dW = None
+         elementwise_affine = False

      if n_cols > BLOCK_SIZE:
          raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
@@ -466,7 +517,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
      # XPU-specific optimization
      kernel_args = {}
      if X.device.type == "xpu":
-         kernel_args["grf_mode"] = "large"
+         set_large_grf_mode(kernel_args)

      if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
          _rms_norm_backward_kernel[grid](
@@ -478,16 +529,17 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
              X.stride(0),
              torch_to_triton_dtype[X.dtype],
              W,
-             W.stride(0),
+             W.stride(0) if elementwise_affine else 0,
              RSTD,
              RSTD.stride(0),
              _dW,
-             _dW.stride(0),
+             _dW.stride(0) if elementwise_affine else 0,
              n_rows,
              n_cols,
              offset,
              rows_per_program,
              casting_mode,
+             elementwise_affine=elementwise_affine,
              BLOCK_SIZE=BLOCK_SIZE,
              num_warps=num_warps,
              **kernel_args,  # XPU-specific optimization
@@ -504,22 +556,26 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
              X.stride(0),
              torch_to_triton_dtype[X.dtype],
              W,
-             W.stride(0),
+             W.stride(0) if elementwise_affine else 0,
              RSTD,
              RSTD.stride(0),
              _dW,
-             _dW.stride(0),
+             _dW.stride(0) if elementwise_affine else 0,
              n_rows,
              n_cols,
              offset,
-             rows_per_program,
              casting_mode,
+             elementwise_affine=elementwise_affine,
              BLOCK_SIZE=BLOCK_SIZE,
              num_warps=num_warps,
              **kernel_args,  # XPU-specific optimization
          )
      dX = dX.view(*shape)
-     dW = _dW.sum(dim=0).to(W.dtype)
+
+     if elementwise_affine:
+         dW = _dW.sum(dim=0).to(W.dtype)
+     else:
+         dW = None

      return dX, dW
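The weight gradient remains a two-stage reduction: each program accumulates a partial fp32 row into `_dW` (one row per SM, EU, or NPU core), and the wrapper finishes the sum on the host. Schematically:

    # _dW: (sm_count, n_cols) fp32 partials, one row per program; skipped when W is None
    dW = _dW.sum(dim=0).to(W.dtype) if W is not None else None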
525
581
 
@@ -553,6 +609,13 @@ class LigerRMSNormFunction(torch.autograd.Function):
          X: (B, T, H) or (BxT, H)
          W: (H,)
          """
+         if isinstance(X, torch.distributed.tensor.DTensor):
+             # Input tensor is the output of a tensor parallel module and
+             # needs to be gathered to a local tensor to compute
+             # RMS layer norm on each TP worker.
+             # TODO: support CP.
+             X = X.full_tensor()
+
          Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
          ctx.offset = offset
          ctx.casting_mode = casting_mode
@@ -560,7 +623,11 @@
          ctx.row_mode = row_mode
          ctx.BLOCK_SIZE = BLOCK_SIZE
          ctx.num_warps = num_warps
-         ctx.save_for_backward(X, W, RSTD)
+         ctx.elementwise_affine = W is not None
+         if W is not None:
+             ctx.save_for_backward(X, W, RSTD)
+         else:
+             ctx.save_for_backward(X, RSTD)
          return Y

      @staticmethod
@@ -569,7 +636,18 @@ class LigerRMSNormFunction(torch.autograd.Function):
          """
          Y: (B, T, H) or (BxT, H)
          """
-         X, W, RSTD = ctx.saved_tensors
+         if ctx.elementwise_affine:
+             X, W, RSTD = ctx.saved_tensors
+         else:
+             X, RSTD = ctx.saved_tensors
+             W = None
+
+         if isinstance(dY, torch.distributed.tensor.DTensor):
+             # Gradients are the output of a tensor parallel module and
+             # need to be gathered to a local tensor for computing the RMS layer.
+             # TODO: support CP.
+             dY = dY.full_tensor()
+
          dX, dW = rms_norm_backward(
              dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
          )
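The new DTensor branches rely on `DTensor.full_tensor()`, which all-gathers a sharded tensor into an ordinary local `torch.Tensor` before the Triton kernels run. A rough standalone illustration (assumes an initialized process group and a 1-D mesh; not taken from the diff):

    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import DTensor, Shard

    mesh = init_device_mesh("cuda", (dist.get_world_size(),))
    x = DTensor.from_local(torch.randn(4, 1024, device="cuda"), mesh, [Shard(0)])
    x_local = x.full_tensor()  # plain tensor, as forward()/backward() now materialize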
liger_kernel/ops/utils.py CHANGED
@@ -78,6 +78,8 @@ def get_amp_custom_fwd_bwd() -> Callable:
          functools.partial(torch.amp.custom_fwd, device_type=device),
          functools.partial(torch.amp.custom_bwd, device_type=device),
      )
+     if hasattr(torch, "npu") and getattr(torch.npu, "amp", None) is not None:
+         return torch.npu.amp.custom_fwd, torch.npu.amp.custom_bwd
      return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
@@ -125,3 +127,26 @@ def element_mul_kernel(
      X_offsets = i + tl.arange(0, BLOCK_SIZE)
      X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
      tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
+
+
+ def get_npu_core_count(default: int = 20) -> int:
+     """Return the NPU vector core count.
+     Falls back to `default` if the Triton runtime or NPU device is unavailable.
+     """
+     try:
+         utils = triton.runtime.driver.active.utils
+         props = utils.get_device_properties(0)
+         return int(props.get("num_vectorcore", default))
+     except Exception:
+         return default
+
+
+ def set_large_grf_mode(kernel_args: dict):
+     """Set large GRF mode for XPU devices."""
+     # On XPU, the Triton shipped with pytorch-xpu is named `pytorch-triton-xpu`,
+     # while a Triton XPU build installed from source is named `triton`.
+     if compare_version("pytorch-triton-xpu", operator.ge, "3.6.0") or compare_version("triton", operator.ge, "3.6.0"):
+         kernel_args["grf_mode"] = "256"
+     else:
+         # API was changed in https://github.com/intel/intel-xpu-backend-for-triton/pull/5430
+         kernel_args["grf_mode"] = "large"
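Both helpers keep device-specific tuning out of the call sites; the rms_norm hunks above now reduce to this pattern:

    kernel_args = {}
    if X.device.type == "xpu":
        set_large_grf_mode(kernel_args)  # "256" on Triton XPU >= 3.6.0, "large" otherwise
    elif X.device.type == "npu":
        sm_count = get_npu_core_count()  # reads num_vectorcore from the driver, defaults to 20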
liger_kernel/transformers/__init__.py CHANGED
@@ -33,6 +33,7 @@ if TYPE_CHECKING:
      from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM  # noqa: F401
      from liger_kernel.transformers.monkey_patch import _apply_liger_kernel  # noqa: F401
      from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_exaone4  # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_falcon_h1  # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma  # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2  # noqa: F401
@@ -41,6 +42,7 @@ if TYPE_CHECKING:
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4  # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v  # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gpt_oss  # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_dense  # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_moe  # noqa: F401
@@ -110,6 +112,7 @@ def __getattr__(name: str):
          "apply_liger_kernel_to_glm4",
          "apply_liger_kernel_to_glm4v",
          "apply_liger_kernel_to_glm4v_moe",
+         "apply_liger_kernel_to_gpt_oss",
          "apply_liger_kernel_to_granite",
          "apply_liger_kernel_to_internvl",
          "apply_liger_kernel_to_llama",
@@ -134,6 +137,7 @@ def __getattr__(name: str):
          "apply_liger_kernel_to_smolvlm",
          "apply_liger_kernel_to_hunyuan_v1_dense",
          "apply_liger_kernel_to_hunyuan_v1_moe",
+         "apply_liger_kernel_to_exaone4",
      }

      if name in monkey_patch_symbols:
@@ -187,6 +191,7 @@ if _TRANSFORMERS_AVAILABLE:
              "apply_liger_kernel_to_glm4",
              "apply_liger_kernel_to_glm4v",
              "apply_liger_kernel_to_glm4v_moe",
+             "apply_liger_kernel_to_gpt_oss",
              "apply_liger_kernel_to_granite",
              "apply_liger_kernel_to_internvl",
              "apply_liger_kernel_to_llama",
@@ -211,5 +216,6 @@ if _TRANSFORMERS_AVAILABLE:
              "apply_liger_kernel_to_smolvlm",
              "apply_liger_kernel_to_hunyuan_v1_dense",
              "apply_liger_kernel_to_hunyuan_v1_moe",
+             "apply_liger_kernel_to_exaone4",
          ]
      )
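With these re-exports, both new patch entry points are reachable from the package root. A hypothetical usage sketch:

    from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss

    apply_liger_kernel_to_gpt_oss()  # patch GPT-OSS modules before the model is instantiated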
liger_kernel/transformers/auto_model.py CHANGED
@@ -1,4 +1,5 @@
  import inspect
+ import logging

  from transformers import AutoConfig
  from transformers import AutoModelForCausalLM
@@ -6,6 +7,8 @@ from transformers import AutoModelForCausalLM
  from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
  from liger_kernel.transformers.monkey_patch import _apply_liger_kernel

+ logger = logging.getLogger(__name__)
+

  def _get_model_config(model_dir, **model_init_kwargs):
      config = AutoConfig.from_pretrained(model_dir, **model_init_kwargs)
@@ -36,3 +39,21 @@ class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
          applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}

          return super().from_pretrained(pretrained_model_name_or_path, *model_args, **applicable_kwargs)
+
+     @classmethod
+     def from_config(cls, config, **kwargs):
+         model_type = getattr(config, "model_type", None)
+         if not model_type:
+             logger.info("Model type could not be determined from model config. No Liger kernels will be applied.")
+             return
+         model_type = config.model_type
+
+         _apply_liger_kernel(model_type, **kwargs)
+
+         # Filter out kwargs that were passed to the apply_liger_* function, which will cause
+         # model initialization errors otherwise
+         apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
+         apply_fn_signature = inspect.signature(apply_fn)
+         applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}
+
+         return super().from_config(config, **applicable_kwargs)
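`from_config` mirrors the existing `from_pretrained` flow: look up `config.model_type`, apply the matching `apply_liger_kernel_to_*` patch, strip the patch-only kwargs, then defer to `transformers`. A sketch (checkpoint id and kwarg chosen for illustration):

    from transformers import AutoConfig
    from liger_kernel.transformers import AutoLigerKernelForCausalLM

    config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B")  # hypothetical model id
    model = AutoLigerKernelForCausalLM.from_config(config, rope=True)  # rope is consumed by the patch fn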
liger_kernel/transformers/cross_entropy.py CHANGED
@@ -2,7 +2,7 @@ from typing import Optional

  import torch

- from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
+ from liger_kernel.ops import LigerCrossEntropyFunction
  from liger_kernel.transformers.functional import CrossEntropyOutput
liger_kernel/transformers/dyt.py CHANGED
@@ -1,7 +1,7 @@
  import torch
  import torch.nn as nn

- from liger_kernel.ops.dyt import LigerDyTFunction
+ from liger_kernel.ops import LigerDyTFunction


  class LigerDyT(nn.Module):
liger_kernel/transformers/experimental/embedding.py CHANGED
@@ -3,7 +3,7 @@ from typing import Optional

  import torch
  import torch.nn as nn

- from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction
+ from liger_kernel.ops import LigerEmbeddingFunction


  class LigerEmbedding(nn.Module):
liger_kernel/transformers/functional.py CHANGED
@@ -3,26 +3,26 @@ from typing import Optional

  import torch

- from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
- from liger_kernel.ops.dyt import LigerDyTFunction
- from liger_kernel.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction
- from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
- from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
- from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction
- from liger_kernel.ops.geglu import LigerGELUMulFunction
- from liger_kernel.ops.group_norm import LigerGroupNormFunction
- from liger_kernel.ops.jsd import LigerJSDFunction
- from liger_kernel.ops.kl_div import LigerKLDivLossFunction
- from liger_kernel.ops.layer_norm import LigerLayerNormFunction
- from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction
- from liger_kernel.ops.poly_norm import LigerPolyNormFunction
- from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
- from liger_kernel.ops.rms_norm import LigerRMSNormFunction
- from liger_kernel.ops.rope import LigerRopeFunction
- from liger_kernel.ops.softmax import LigerSoftmaxFunction
- from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
- from liger_kernel.ops.swiglu import LigerSiLUMulFunction
- from liger_kernel.ops.tvd import LigerTVDLossFunction
+ from liger_kernel.ops import LigerCrossEntropyFunction
+ from liger_kernel.ops import LigerDyTFunction
+ from liger_kernel.ops import LigerFusedAddRMSNormFunction
+ from liger_kernel.ops import LigerFusedLinearCrossEntropyFunction
+ from liger_kernel.ops import LigerFusedLinearJSDFunction
+ from liger_kernel.ops import LigerFusedNeighborhoodAttentionFunction
+ from liger_kernel.ops import LigerGELUMulFunction
+ from liger_kernel.ops import LigerGroupNormFunction
+ from liger_kernel.ops import LigerJSDFunction
+ from liger_kernel.ops import LigerKLDivLossFunction
+ from liger_kernel.ops import LigerLayerNormFunction
+ from liger_kernel.ops import LigerMultiTokenAttentionFunction
+ from liger_kernel.ops import LigerPolyNormFunction
+ from liger_kernel.ops import LigerQwen2VLMRopeFunction
+ from liger_kernel.ops import LigerRMSNormFunction
+ from liger_kernel.ops import LigerRopeFunction
+ from liger_kernel.ops import LigerSiLUMulFunction
+ from liger_kernel.ops import LigerSoftmaxFunction
+ from liger_kernel.ops import LigerSparsemaxFunction
+ from liger_kernel.ops import LigerTVDLossFunction


  @dataclass
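The same mechanical change runs through all of these files: deep module paths such as `liger_kernel.ops.rms_norm` are replaced by the `liger_kernel.ops` package root, whose new `__init__.py` (+141 lines, alongside `ops/backends/registry.py`) re-exports each `Liger*Function` and appears to give the backend registry a single dispatch point. Downstream code can follow suit:

    # 0.6.5: backend-agnostic import from the package root
    from liger_kernel.ops import LigerRMSNormFunction

    # 0.6.4 and earlier: deep module path
    # from liger_kernel.ops.rms_norm import LigerRMSNormFunction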