liger-kernel 0.6.4__py3-none-any.whl → 0.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +7 -1
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +10 -3
  3. liger_kernel/chunked_loss/jsd_loss.py +21 -6
  4. liger_kernel/ops/__init__.py +141 -0
  5. liger_kernel/ops/backends/README.md +151 -0
  6. liger_kernel/ops/backends/__init__.py +13 -0
  7. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  8. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +492 -0
  9. liger_kernel/ops/backends/_ascend/ops/__init__.py +61 -0
  10. liger_kernel/ops/backends/_ascend/ops/embedding.py +214 -0
  11. liger_kernel/ops/backends/_ascend/ops/geglu.py +191 -0
  12. liger_kernel/ops/backends/_ascend/ops/llama4_rope.py +298 -0
  13. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +275 -0
  14. liger_kernel/ops/backends/_ascend/ops/rope.py +265 -0
  15. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  16. liger_kernel/ops/backends/_ascend/ops/tvd.py +223 -0
  17. liger_kernel/ops/backends/_ascend/ub_manager.py +367 -0
  18. liger_kernel/ops/backends/registry.py +61 -0
  19. liger_kernel/ops/cross_entropy.py +14 -4
  20. liger_kernel/ops/dyt.py +5 -2
  21. liger_kernel/ops/fused_add_rms_norm.py +21 -23
  22. liger_kernel/ops/fused_linear_cross_entropy.py +2 -1
  23. liger_kernel/ops/geglu.py +5 -3
  24. liger_kernel/ops/group_norm.py +12 -8
  25. liger_kernel/ops/kl_div.py +8 -11
  26. liger_kernel/ops/layer_norm.py +17 -16
  27. liger_kernel/ops/poly_norm.py +19 -21
  28. liger_kernel/ops/rms_norm.py +149 -71
  29. liger_kernel/ops/utils.py +25 -0
  30. liger_kernel/transformers/__init__.py +6 -0
  31. liger_kernel/transformers/auto_model.py +21 -0
  32. liger_kernel/transformers/cross_entropy.py +1 -1
  33. liger_kernel/transformers/dyt.py +1 -1
  34. liger_kernel/transformers/experimental/embedding.py +1 -1
  35. liger_kernel/transformers/functional.py +20 -20
  36. liger_kernel/transformers/fused_add_rms_norm.py +1 -1
  37. liger_kernel/transformers/fused_linear_cross_entropy.py +1 -1
  38. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  39. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  40. liger_kernel/transformers/geglu.py +1 -1
  41. liger_kernel/transformers/group_norm.py +1 -1
  42. liger_kernel/transformers/grpo_loss.py +1 -1
  43. liger_kernel/transformers/jsd.py +1 -1
  44. liger_kernel/transformers/kl_div.py +1 -1
  45. liger_kernel/transformers/layer_norm.py +1 -1
  46. liger_kernel/transformers/llama4_rope.py +1 -1
  47. liger_kernel/transformers/model/exaone4.py +136 -0
  48. liger_kernel/transformers/model/gemma2.py +3 -3
  49. liger_kernel/transformers/model/gemma3.py +11 -5
  50. liger_kernel/transformers/model/gpt_oss.py +211 -0
  51. liger_kernel/transformers/model/loss_utils.py +6 -0
  52. liger_kernel/transformers/model/paligemma.py +1 -0
  53. liger_kernel/transformers/monkey_patch.py +196 -39
  54. liger_kernel/transformers/multi_token_attention.py +1 -1
  55. liger_kernel/transformers/poly_norm.py +1 -1
  56. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  57. liger_kernel/transformers/rms_norm.py +8 -3
  58. liger_kernel/transformers/rope.py +28 -27
  59. liger_kernel/transformers/softmax.py +1 -1
  60. liger_kernel/transformers/sparsemax.py +1 -1
  61. liger_kernel/transformers/swiglu.py +1 -1
  62. liger_kernel/transformers/tiled_mlp.py +5 -13
  63. liger_kernel/transformers/tvd.py +1 -1
  64. liger_kernel/utils.py +54 -0
  65. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/METADATA +11 -4
  66. liger_kernel-0.6.5.dist-info/RECORD +134 -0
  67. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/WHEEL +1 -1
  68. liger_kernel-0.6.4.dist-info/RECORD +0 -118
  69. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/licenses/LICENSE +0 -0
  70. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/licenses/NOTICE +0 -0
  71. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/top_level.txt +0 -0
liger_kernel/ops/dyt.py CHANGED
@@ -6,9 +6,11 @@ import triton.language as tl
6
6
 
7
7
  from liger_kernel.ops.utils import compare_version
8
8
  from liger_kernel.ops.utils import ensure_contiguous
9
+ from liger_kernel.ops.utils import get_npu_core_count
9
10
  from liger_kernel.ops.utils import infer_device
11
+ from liger_kernel.utils import is_npu_available
10
12
 
11
- if compare_version("triton", operator.ge, "3.0.0"):
13
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
12
14
  try:
13
15
  # typical import path with dispatch available
14
16
  from triton.language.extra.libdevice import tanh
@@ -125,7 +127,8 @@ def liger_dyt_bwd(dy, x, alpha, gamma, beta):
125
127
  NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
126
128
  elif device == "xpu":
127
129
  NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
128
-
130
+ elif device == "npu":
131
+ NUM_SMS = get_npu_core_count()
129
132
  da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
130
133
  dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
131
134
  db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
liger_kernel/ops/fused_add_rms_norm.py CHANGED
@@ -8,9 +8,12 @@ import triton.language as tl
8
8
  from liger_kernel.ops.utils import calculate_settings
9
9
  from liger_kernel.ops.utils import compare_version
10
10
  from liger_kernel.ops.utils import ensure_contiguous
11
+ from liger_kernel.ops.utils import get_npu_core_count
12
+ from liger_kernel.ops.utils import set_large_grf_mode
11
13
  from liger_kernel.ops.utils import torch_to_triton_dtype
14
+ from liger_kernel.utils import is_npu_available
12
15
 
13
- if compare_version("triton", operator.ge, "3.0.0"):
16
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
14
17
  try:
15
18
  # typical import path with dispatch available
16
19
  from triton.language.extra.libdevice import rsqrt
@@ -160,23 +163,21 @@ def _fused_add_rms_norm_backward_kernel(
160
163
 
161
164
  dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
162
165
 
163
- dY_ptr += row_start * dY_row_stride
164
- dX_ptr += row_start * dX_row_stride
165
- if has_dS_out:
166
- dS_out_ptr += row_start * dS_out_row_stride
167
-
168
- X_ptr += row_start * X_row_stride
169
- RSTD_ptr += row_start
170
-
171
166
  W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
172
167
  W_row = W_row + offset
173
168
 
174
- for _ in range(row_start, row_end):
175
- dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
176
- X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
169
+ for row_idx in range(row_start, row_end):
170
+ dy_base = dY_ptr + row_idx * dY_row_stride
171
+ dx_base = dX_ptr + row_idx * dX_row_stride
172
+
173
+ x_base = X_ptr + row_idx * X_row_stride
174
+ rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
175
+
176
+ dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0)
177
+ X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0)
177
178
 
178
179
  # Get cached rms
179
- rstd_row = tl.load(RSTD_ptr)
180
+ rstd_row = tl.load(rstd_base)
180
181
 
181
182
  X_row = X_row.to(tl.float32)
182
183
 
@@ -193,11 +194,11 @@ def _fused_add_rms_norm_backward_kernel(
193
194
  dX_row = rstd_row * m
194
195
 
195
196
  if has_dS_out:
196
- dS_out_row = tl.load(dS_out_ptr + col_offsets, mask=mask, other=0.0)
197
+ ds_base = dS_out_ptr + row_idx * dS_out_row_stride
198
+ dS_out_row = tl.load(ds_base + col_offsets, mask=mask, other=0.0)
197
199
  dX_row += (rstd_row) * (
198
200
  -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
199
201
  ) + dS_out_row
200
- dS_out_ptr += dS_out_row_stride
201
202
  else:
202
203
  dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
203
204
 
@@ -208,12 +209,7 @@ def _fused_add_rms_norm_backward_kernel(
208
209
  # here X_row is already in fp32 (see previous if block)
209
210
  dW_row += dY_row * (X_row * rstd_row)
210
211
 
211
- tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
212
-
213
- dY_ptr += dY_row_stride
214
- dX_ptr += dX_row_stride
215
- X_ptr += X_row_stride
216
- RSTD_ptr += RSTD_row_stride
212
+ tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask)
217
213
 
218
214
  tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
219
215
 
@@ -252,7 +248,7 @@ def fused_add_rms_norm_forward(X, R, W, eps, offset, casting_mode):
252
248
  # XPU-specific optimization
253
249
  kernel_args = {}
254
250
  if X.device.type == "xpu":
255
- kernel_args["grf_mode"] = "large"
251
+ set_large_grf_mode(kernel_args)
256
252
 
257
253
  # TODO: add _block_fused_add_rms_norm_forward_kernel
258
254
  _fused_add_rms_norm_forward_kernel[(n_rows,)](
@@ -293,6 +289,8 @@ def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BL
293
289
  sm_count = torch.cuda.get_device_properties(S.device).multi_processor_count
294
290
  elif S.device.type == "xpu":
295
291
  sm_count = torch.xpu.get_device_properties(S.device).gpu_eu_count
292
+ elif S.device.type == "npu":
293
+ sm_count = get_npu_core_count()
296
294
 
297
295
  # fp32 for numerical stability especially.
298
296
  _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
@@ -310,7 +308,7 @@ def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BL
310
308
  # XPU-specific optimization
311
309
  kernel_args = {}
312
310
  if S.device.type == "xpu":
313
- kernel_args["grf_mode"] = "large"
311
+ set_large_grf_mode(kernel_args)
314
312
 
315
313
  # TODO: add _block_fused_add_rms_norm_backward_kernel
316
314
  _fused_add_rms_norm_backward_kernel[grid](
liger_kernel/ops/fused_linear_cross_entropy.py CHANGED
@@ -6,11 +6,12 @@ from liger_kernel.ops.utils import amp_custom_bwd
6
6
  from liger_kernel.ops.utils import amp_custom_fwd
7
7
  from liger_kernel.ops.utils import element_mul_kernel
8
8
  from liger_kernel.ops.utils import is_hip
9
+ from liger_kernel.utils import infer_device
9
10
 
10
11
  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
11
12
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
12
13
  # The optimal maximum block size depends on your hardware, your kernel, and your dtype
13
- MAX_FUSED_SIZE = 65536 // 2
14
+ MAX_FUSED_SIZE = 2048 if infer_device() == "npu" else 65536 // 2
14
15
 
15
16
 
16
17
  def fused_linear_cross_entropy_forward(
liger_kernel/ops/geglu.py CHANGED
@@ -7,8 +7,9 @@ import triton.language as tl
7
7
  from liger_kernel.ops.utils import calculate_settings
8
8
  from liger_kernel.ops.utils import compare_version
9
9
  from liger_kernel.ops.utils import ensure_contiguous
10
+ from liger_kernel.utils import is_npu_available
10
11
 
11
- if compare_version("triton", operator.ge, "3.0.0"):
12
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
12
13
  try:
13
14
  # typical import path with dispatch available
14
15
  from triton.language.extra.libdevice import tanh
@@ -66,8 +67,9 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
66
67
  tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
67
68
  tanh_result = tanh(tanh_arg)
68
69
  geglu_a = 0.5 * a_row * (1 + tanh_result)
70
+ geglu_a = geglu_a.to(dc_row.dtype).to(tl.float32)
69
71
 
70
- db_row = dc_row * geglu_a
72
+ db_row = dc_row.cast(tl.float32) * geglu_a
71
73
 
72
74
  # Gradient w.r.t. a can be computed with:
73
75
  # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
@@ -78,7 +80,7 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
78
80
  da_row = dc_row * b_row * (term1 + term2)
79
81
 
80
82
  tl.store(a + col_offsets, da_row, mask=mask)
81
- tl.store(b + col_offsets, db_row, mask=mask)
83
+ tl.store(b + col_offsets, db_row.to(dc_row.dtype), mask=mask)
82
84
 
83
85
 
84
86
  def geglu_forward(a, b):
liger_kernel/ops/group_norm.py CHANGED
@@ -6,8 +6,10 @@ import triton.language as tl
6
6
 
7
7
  from liger_kernel.ops.utils import compare_version
8
8
  from liger_kernel.ops.utils import ensure_contiguous
9
+ from liger_kernel.utils import infer_device
10
+ from liger_kernel.utils import is_npu_available
9
11
 
10
- if compare_version("triton", operator.ge, "3.0.0"):
12
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
11
13
  try:
12
14
  # typical import path with dispatch available
13
15
  from triton.language.extra.libdevice import rsqrt
@@ -17,7 +19,10 @@ if compare_version("triton", operator.ge, "3.0.0"):
17
19
  else:
18
20
  from triton.language.math import rsqrt
19
21
 
20
- MAX_FUSED_SIZE = 65536
22
+ if infer_device() == "npu":
23
+ MAX_FUSED_SIZE = 16384 # 8192
24
+ else:
25
+ MAX_FUSED_SIZE = 65536
21
26
 
22
27
 
23
28
  @triton.jit
@@ -77,15 +82,14 @@ def _group_norm_forward_kernel(
77
82
  for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
78
83
  W = tl.load(W_ptr + channel_idx)
79
84
  B = tl.load(B_ptr + channel_idx)
80
- for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
85
+ # Calculate channel offset within the group
86
+ channel_offset = (channel_idx - group_idx * channels_per_group) * hidden_size_per_channel
87
+ for i in tl.range(0, hidden_size_per_channel, BLOCK_SIZE):
81
88
  hidden_size_offsets = i + block_range
82
89
  mask = hidden_size_offsets < hidden_size_per_channel
83
- X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
90
+ X = tl.load(X_ptr + channel_offset + hidden_size_offsets, mask=mask, other=m)
84
91
  Y = (X - m) * rstd * W + B
85
- tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
86
-
87
- X_ptr += hidden_size_per_channel
88
- Y_ptr += hidden_size_per_channel
92
+ tl.store(Y_ptr + channel_offset + hidden_size_offsets, Y, mask=mask)
89
93
 
90
94
  tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
91
95
  tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
liger_kernel/ops/kl_div.py CHANGED
@@ -21,7 +21,12 @@ def get_num_warps(BLOCK_SIZE):
21
21
  return num_warps
22
22
 
23
23
 
24
- MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best
24
+ if infer_device() == "xpu":
25
+ MAX_FUSED_SIZE = 8192
26
+ elif infer_device() == "npu":
27
+ MAX_FUSED_SIZE = 8192
28
+ else:
29
+ MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best
25
30
 
26
31
  REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
27
32
 
@@ -116,11 +121,7 @@ def _kldiv_kernel_backward(
116
121
 
117
122
  def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
118
123
  BT, V = y_pred.shape
119
- BLOCK_SIZE = (
120
- min(8192, triton.next_power_of_2(V))
121
- if infer_device() == "xpu"
122
- else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
123
- )
124
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
124
125
  num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
125
126
 
126
127
  grid = (BT,)
@@ -159,11 +160,7 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
159
160
 
160
161
  def kldiv_backward_triton(target, grad_output, new_grads, log_target):
161
162
  BT, V = target.shape
162
- BLOCK_SIZE = (
163
- min(8192, triton.next_power_of_2(V))
164
- if infer_device() == "xpu"
165
- else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
166
- )
163
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
167
164
  num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
168
165
 
169
166
  grid = (BT,)
liger_kernel/ops/layer_norm.py CHANGED
@@ -8,8 +8,11 @@ import triton.language as tl
8
8
  from liger_kernel.ops.utils import calculate_settings
9
9
  from liger_kernel.ops.utils import compare_version
10
10
  from liger_kernel.ops.utils import ensure_contiguous
11
+ from liger_kernel.ops.utils import get_npu_core_count
12
+ from liger_kernel.ops.utils import set_large_grf_mode
13
+ from liger_kernel.utils import is_npu_available
11
14
 
12
- if compare_version("triton", operator.ge, "3.0.0"):
15
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
13
16
  try:
14
17
  # typical import path with dispatch available
15
18
  from triton.language.extra.libdevice import rsqrt
@@ -123,14 +126,14 @@ def _layer_norm_backward_kernel(
123
126
  w = tl.load(W_ptr + cols, mask=mask, other=0.0)
124
127
  w_f32 = w.to(tl.float32)
125
128
 
126
- # Calculate pointers for this specific row
127
- row_X_ptr = X_ptr + row_start * stride_x
128
- row_DX_ptr = DX_ptr + row_start * stride_dx
129
- row_DY_ptr = DY_ptr + row_start * stride_dy
130
- row_Mean_ptr = Mean_ptr + row_start
131
- row_RSTD_ptr = RSTD_ptr + row_start
129
+ for row_idx in range(row_start, row_end):
130
+ # Calculate pointers for this specific row
131
+ row_X_ptr = X_ptr + row_idx * stride_x
132
+ row_DX_ptr = DX_ptr + row_idx * stride_dx
133
+ row_DY_ptr = DY_ptr + row_idx * stride_dy
134
+ row_Mean_ptr = Mean_ptr + row_idx * stride_mean
135
+ row_RSTD_ptr = RSTD_ptr + row_idx * stride_rstd
132
136
 
133
- for _ in range(row_start, row_end):
134
137
  # Load data for this row
135
138
  x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
136
139
  dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
@@ -159,12 +162,6 @@ def _layer_norm_backward_kernel(
159
162
  dW_row += dw
160
163
  db_row += db
161
164
 
162
- row_X_ptr += stride_x
163
- row_DX_ptr += stride_dx
164
- row_DY_ptr += stride_dy
165
- row_Mean_ptr += stride_mean
166
- row_RSTD_ptr += stride_rstd
167
-
168
165
  tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
169
166
  tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)
170
167
 
@@ -203,7 +200,7 @@ def layer_norm_forward(X, W, B, eps):
203
200
  # XPU-specific optimization
204
201
  kernel_args = {}
205
202
  if X.device.type == "xpu":
206
- kernel_args["grf_mode"] = "large"
203
+ set_large_grf_mode(kernel_args)
207
204
 
208
205
  # Launch kernel with one thread block per row for optimal performance
209
206
  grid = (n_rows,)
@@ -253,6 +250,8 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
253
250
  sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
254
251
  elif X.device.type == "xpu":
255
252
  sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
253
+ elif X.device.type == "npu":
254
+ sm_count = get_npu_core_count()
256
255
 
257
256
  # fp32 for numerical stability especially.
258
257
  _DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
@@ -271,7 +270,8 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
271
270
  kernel_args = {"num_warps": num_warps}
272
271
  # XPU-specific optimization
273
272
  if X.device.type == "xpu":
274
- kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
273
+ kernel_args.update({"num_warps": 32, "num_stages": 4})
274
+ set_large_grf_mode(kernel_args)
275
275
 
276
276
  # Launch kernel with one thread block per row for optimal performance
277
277
  _layer_norm_backward_kernel[grid](
@@ -300,6 +300,7 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
300
300
  DX = DX.view(*shape)
301
301
  DW = _DW.sum(dim=0).to(W.dtype)
302
302
  DB = _DB.sum(dim=0).to(B.dtype)
303
+
303
304
  return DX, DW, DB
304
305
 
305
306
 
liger_kernel/ops/poly_norm.py CHANGED
@@ -7,8 +7,11 @@ import triton.language as tl
7
7
  from liger_kernel.ops.utils import calculate_settings
8
8
  from liger_kernel.ops.utils import compare_version
9
9
  from liger_kernel.ops.utils import ensure_contiguous
10
+ from liger_kernel.ops.utils import get_npu_core_count
11
+ from liger_kernel.ops.utils import set_large_grf_mode
12
+ from liger_kernel.utils import is_npu_available
10
13
 
11
- if compare_version("triton", operator.ge, "3.0.0"):
14
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
12
15
  try:
13
16
  from triton.language.extra.libdevice import rsqrt
14
17
  except ModuleNotFoundError:
@@ -138,20 +141,19 @@ def _poly_norm_backward_kernel(
138
141
  w1 = tl.load(W_ptr + 1).to(tl.float32)
139
142
  w2 = tl.load(W_ptr + 2).to(tl.float32)
140
143
 
141
- dY_ptr += row_start * dY_row_stride
142
- dX_ptr += row_start * dX_row_stride
143
- X_ptr += row_start * X_row_stride
144
- RSTD_ptr += row_start * RSTD_row_stride
144
+ for row_idx in range(row_start, row_end):
145
+ dy_base = dY_ptr + row_idx * dY_row_stride
146
+ x_base = X_ptr + row_idx * X_row_stride
147
+ dx_base = dX_ptr + row_idx * dX_row_stride
148
+ rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
145
149
 
146
- for _ in range(row_start, row_end):
147
- # Load input and gradient
148
- dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
149
- X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
150
+ dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0).to(tl.float32)
151
+ X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0).to(tl.float32)
150
152
 
151
153
  # Load cached rstd values
152
- rstd_3 = tl.load(RSTD_ptr + 0).to(tl.float32)
153
- rstd_2 = tl.load(RSTD_ptr + 1).to(tl.float32)
154
- rstd_1 = tl.load(RSTD_ptr + 2).to(tl.float32)
154
+ rstd_3 = tl.load(rstd_base + 0).to(tl.float32)
155
+ rstd_2 = tl.load(rstd_base + 1).to(tl.float32)
156
+ rstd_1 = tl.load(rstd_base + 2).to(tl.float32)
155
157
 
156
158
  # Compute powers
157
159
  X_pow3 = X_row * X_row * X_row
@@ -188,13 +190,7 @@ def _poly_norm_backward_kernel(
188
190
  dX_row = grad_x_3 + grad_x_2 + grad_x_1
189
191
 
190
192
  # Store gradient
191
- tl.store(dX_ptr + col_offsets, dX_row, mask=mask)
192
-
193
- # Update pointers
194
- dY_ptr += dY_row_stride
195
- dX_ptr += dX_row_stride
196
- X_ptr += X_row_stride
197
- RSTD_ptr += RSTD_row_stride
193
+ tl.store(dx_base + col_offsets, dX_row, mask=mask)
198
194
 
199
195
  # Store accumulated gradients (scalars)
200
196
  tl.store(dW_ptr + row_block_id * dW_row_stride + 0, dW0_acc)
@@ -237,7 +233,7 @@ def poly_norm_forward(X, W, B, eps=1e-6):
237
233
  # XPU-specific optimization
238
234
  kernel_args = {}
239
235
  if X.device.type == "xpu":
240
- kernel_args["grf_mode"] = "large"
236
+ set_large_grf_mode(kernel_args)
241
237
 
242
238
  # Launch kernel
243
239
  _poly_norm_forward_kernel[(n_rows,)](
@@ -290,6 +286,8 @@ def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
290
286
  sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
291
287
  elif X.device.type == "xpu":
292
288
  sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
289
+ elif X.device.type == "npu":
290
+ sm_count = get_npu_core_count()
293
291
 
294
292
  # Allocate or reuse gradients
295
293
  if in_place is True:
@@ -306,7 +304,7 @@ def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
306
304
  # XPU-specific optimization
307
305
  kernel_args = {}
308
306
  if X.device.type == "xpu":
309
- kernel_args["grf_mode"] = "large"
307
+ set_large_grf_mode(kernel_args)
310
308
 
311
309
  # Launch backward kernel
312
310
  _poly_norm_backward_kernel[grid](