liger-kernel 0.6.4__py3-none-any.whl → 0.6.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/cosine_similarity_loss.py +7 -1
- liger_kernel/chunked_loss/fused_linear_distillation.py +10 -3
- liger_kernel/chunked_loss/jsd_loss.py +21 -6
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +492 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +61 -0
- liger_kernel/ops/backends/_ascend/ops/embedding.py +214 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +191 -0
- liger_kernel/ops/backends/_ascend/ops/llama4_rope.py +298 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +275 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +265 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ops/tvd.py +223 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +367 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +14 -4
- liger_kernel/ops/dyt.py +5 -2
- liger_kernel/ops/fused_add_rms_norm.py +21 -23
- liger_kernel/ops/fused_linear_cross_entropy.py +2 -1
- liger_kernel/ops/geglu.py +5 -3
- liger_kernel/ops/group_norm.py +12 -8
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +17 -16
- liger_kernel/ops/poly_norm.py +19 -21
- liger_kernel/ops/rms_norm.py +149 -71
- liger_kernel/ops/utils.py +25 -0
- liger_kernel/transformers/__init__.py +6 -0
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +1 -1
- liger_kernel/transformers/dyt.py +1 -1
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/functional.py +20 -20
- liger_kernel/transformers/fused_add_rms_norm.py +1 -1
- liger_kernel/transformers/fused_linear_cross_entropy.py +1 -1
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +1 -1
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +1 -1
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/gemma2.py +3 -3
- liger_kernel/transformers/model/gemma3.py +11 -5
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/loss_utils.py +6 -0
- liger_kernel/transformers/model/paligemma.py +1 -0
- liger_kernel/transformers/monkey_patch.py +196 -39
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +1 -1
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +8 -3
- liger_kernel/transformers/rope.py +28 -27
- liger_kernel/transformers/softmax.py +1 -1
- liger_kernel/transformers/sparsemax.py +1 -1
- liger_kernel/transformers/swiglu.py +1 -1
- liger_kernel/transformers/tiled_mlp.py +5 -13
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +54 -0
- {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/METADATA +11 -4
- liger_kernel-0.6.5.dist-info/RECORD +134 -0
- {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/WHEEL +1 -1
- liger_kernel-0.6.4.dist-info/RECORD +0 -118
- {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/top_level.txt +0 -0
liger_kernel/ops/dyt.py
CHANGED
|
@@ -6,9 +6,11 @@ import triton.language as tl
|
|
|
6
6
|
|
|
7
7
|
from liger_kernel.ops.utils import compare_version
|
|
8
8
|
from liger_kernel.ops.utils import ensure_contiguous
|
|
9
|
+
from liger_kernel.ops.utils import get_npu_core_count
|
|
9
10
|
from liger_kernel.ops.utils import infer_device
|
|
11
|
+
from liger_kernel.utils import is_npu_available
|
|
10
12
|
|
|
11
|
-
if compare_version("triton", operator.ge, "3.0.0"):
|
|
13
|
+
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
|
12
14
|
try:
|
|
13
15
|
# typical import path with dispatch available
|
|
14
16
|
from triton.language.extra.libdevice import tanh
|
|
@@ -125,7 +127,8 @@ def liger_dyt_bwd(dy, x, alpha, gamma, beta):
|
|
|
125
127
|
NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
|
|
126
128
|
elif device == "xpu":
|
|
127
129
|
NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
|
|
128
|
-
|
|
130
|
+
elif device == "npu":
|
|
131
|
+
NUM_SMS = get_npu_core_count()
|
|
129
132
|
da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
|
|
130
133
|
dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
|
|
131
134
|
db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
|
liger_kernel/ops/fused_add_rms_norm.py
CHANGED
|
@@ -8,9 +8,12 @@ import triton.language as tl
|
|
|
8
8
|
from liger_kernel.ops.utils import calculate_settings
|
|
9
9
|
from liger_kernel.ops.utils import compare_version
|
|
10
10
|
from liger_kernel.ops.utils import ensure_contiguous
|
|
11
|
+
from liger_kernel.ops.utils import get_npu_core_count
|
|
12
|
+
from liger_kernel.ops.utils import set_large_grf_mode
|
|
11
13
|
from liger_kernel.ops.utils import torch_to_triton_dtype
|
|
14
|
+
from liger_kernel.utils import is_npu_available
|
|
12
15
|
|
|
13
|
-
if compare_version("triton", operator.ge, "3.0.0"):
|
|
16
|
+
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
|
14
17
|
try:
|
|
15
18
|
# typical import path with dispatch available
|
|
16
19
|
from triton.language.extra.libdevice import rsqrt
|
|
@@ -160,23 +163,21 @@ def _fused_add_rms_norm_backward_kernel(
|
|
|
160
163
|
|
|
161
164
|
dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
|
162
165
|
|
|
163
|
-
dY_ptr += row_start * dY_row_stride
|
|
164
|
-
dX_ptr += row_start * dX_row_stride
|
|
165
|
-
if has_dS_out:
|
|
166
|
-
dS_out_ptr += row_start * dS_out_row_stride
|
|
167
|
-
|
|
168
|
-
X_ptr += row_start * X_row_stride
|
|
169
|
-
RSTD_ptr += row_start
|
|
170
|
-
|
|
171
166
|
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
|
|
172
167
|
W_row = W_row + offset
|
|
173
168
|
|
|
174
|
-
for
|
|
175
|
-
|
|
176
|
-
|
|
169
|
+
for row_idx in range(row_start, row_end):
|
|
170
|
+
dy_base = dY_ptr + row_idx * dY_row_stride
|
|
171
|
+
dx_base = dX_ptr + row_idx * dX_row_stride
|
|
172
|
+
|
|
173
|
+
x_base = X_ptr + row_idx * X_row_stride
|
|
174
|
+
rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
|
|
175
|
+
|
|
176
|
+
dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0)
|
|
177
|
+
X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0)
|
|
177
178
|
|
|
178
179
|
# Get cached rms
|
|
179
|
-
rstd_row = tl.load(
|
|
180
|
+
rstd_row = tl.load(rstd_base)
|
|
180
181
|
|
|
181
182
|
X_row = X_row.to(tl.float32)
|
|
182
183
|
|
|
@@ -193,11 +194,11 @@ def _fused_add_rms_norm_backward_kernel(
|
|
|
193
194
|
dX_row = rstd_row * m
|
|
194
195
|
|
|
195
196
|
if has_dS_out:
|
|
196
|
-
|
|
197
|
+
ds_base = dS_out_ptr + row_idx * dS_out_row_stride
|
|
198
|
+
dS_out_row = tl.load(ds_base + col_offsets, mask=mask, other=0.0)
|
|
197
199
|
dX_row += (rstd_row) * (
|
|
198
200
|
-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
|
|
199
201
|
) + dS_out_row
|
|
200
|
-
dS_out_ptr += dS_out_row_stride
|
|
201
202
|
else:
|
|
202
203
|
dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
|
|
203
204
|
|
|
@@ -208,12 +209,7 @@ def _fused_add_rms_norm_backward_kernel(
|
|
|
208
209
|
# here X_row is already in fp32 (see previous if block)
|
|
209
210
|
dW_row += dY_row * (X_row * rstd_row)
|
|
210
211
|
|
|
211
|
-
tl.store(
|
|
212
|
-
|
|
213
|
-
dY_ptr += dY_row_stride
|
|
214
|
-
dX_ptr += dX_row_stride
|
|
215
|
-
X_ptr += X_row_stride
|
|
216
|
-
RSTD_ptr += RSTD_row_stride
|
|
212
|
+
tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask)
|
|
217
213
|
|
|
218
214
|
tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
|
|
219
215
|
|
|
@@ -252,7 +248,7 @@ def fused_add_rms_norm_forward(X, R, W, eps, offset, casting_mode):
|
|
|
252
248
|
# XPU-specific optimization
|
|
253
249
|
kernel_args = {}
|
|
254
250
|
if X.device.type == "xpu":
|
|
255
|
-
kernel_args
|
|
251
|
+
set_large_grf_mode(kernel_args)
|
|
256
252
|
|
|
257
253
|
# TODO: add _block_fused_add_rms_norm_forward_kernel
|
|
258
254
|
_fused_add_rms_norm_forward_kernel[(n_rows,)](
|
|
@@ -293,6 +289,8 @@ def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BL
|
|
|
293
289
|
sm_count = torch.cuda.get_device_properties(S.device).multi_processor_count
|
|
294
290
|
elif S.device.type == "xpu":
|
|
295
291
|
sm_count = torch.xpu.get_device_properties(S.device).gpu_eu_count
|
|
292
|
+
elif S.device.type == "npu":
|
|
293
|
+
sm_count = get_npu_core_count()
|
|
296
294
|
|
|
297
295
|
# fp32 for numerical stability especially.
|
|
298
296
|
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
|
|
@@ -310,7 +308,7 @@ def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BL
|
|
|
310
308
|
# XPU-specific optimization
|
|
311
309
|
kernel_args = {}
|
|
312
310
|
if S.device.type == "xpu":
|
|
313
|
-
kernel_args
|
|
311
|
+
set_large_grf_mode(kernel_args)
|
|
314
312
|
|
|
315
313
|
# TODO: add _block_fused_add_rms_norm_backward_kernel
|
|
316
314
|
_fused_add_rms_norm_backward_kernel[grid](
|
liger_kernel/ops/fused_linear_cross_entropy.py
CHANGED
|
@@ -6,11 +6,12 @@ from liger_kernel.ops.utils import amp_custom_bwd
|
|
|
6
6
|
from liger_kernel.ops.utils import amp_custom_fwd
|
|
7
7
|
from liger_kernel.ops.utils import element_mul_kernel
|
|
8
8
|
from liger_kernel.ops.utils import is_hip
|
|
9
|
+
from liger_kernel.utils import infer_device
|
|
9
10
|
|
|
10
11
|
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
|
|
11
12
|
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
|
|
12
13
|
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
|
|
13
|
-
MAX_FUSED_SIZE = 65536 // 2
|
|
14
|
+
MAX_FUSED_SIZE = 2048 if infer_device() == "npu" else 65536 // 2
|
|
14
15
|
|
|
15
16
|
|
|
16
17
|
def fused_linear_cross_entropy_forward(
|
liger_kernel/ops/geglu.py
CHANGED
|
@@ -7,8 +7,9 @@ import triton.language as tl
|
|
|
7
7
|
from liger_kernel.ops.utils import calculate_settings
|
|
8
8
|
from liger_kernel.ops.utils import compare_version
|
|
9
9
|
from liger_kernel.ops.utils import ensure_contiguous
|
|
10
|
+
from liger_kernel.utils import is_npu_available
|
|
10
11
|
|
|
11
|
-
if compare_version("triton", operator.ge, "3.0.0"):
|
|
12
|
+
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
|
12
13
|
try:
|
|
13
14
|
# typical import path with dispatch available
|
|
14
15
|
from triton.language.extra.libdevice import tanh
|
|
@@ -66,8 +67,9 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
|
|
|
66
67
|
tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
|
|
67
68
|
tanh_result = tanh(tanh_arg)
|
|
68
69
|
geglu_a = 0.5 * a_row * (1 + tanh_result)
|
|
70
|
+
geglu_a = geglu_a.to(dc_row.dtype).to(tl.float32)
|
|
69
71
|
|
|
70
|
-
db_row = dc_row * geglu_a
|
|
72
|
+
db_row = dc_row.cast(tl.float32) * geglu_a
|
|
71
73
|
|
|
72
74
|
# Gradient w.r.t. a can be computed with:
|
|
73
75
|
# b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
|
|
@@ -78,7 +80,7 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
|
|
|
78
80
|
da_row = dc_row * b_row * (term1 + term2)
|
|
79
81
|
|
|
80
82
|
tl.store(a + col_offsets, da_row, mask=mask)
|
|
81
|
-
tl.store(b + col_offsets, db_row, mask=mask)
|
|
83
|
+
tl.store(b + col_offsets, db_row.to(dc_row.dtype), mask=mask)
|
|
82
84
|
|
|
83
85
|
|
|
84
86
|
def geglu_forward(a, b):
|
liger_kernel/ops/group_norm.py
CHANGED
|
@@ -6,8 +6,10 @@ import triton.language as tl
|
|
|
6
6
|
|
|
7
7
|
from liger_kernel.ops.utils import compare_version
|
|
8
8
|
from liger_kernel.ops.utils import ensure_contiguous
|
|
9
|
+
from liger_kernel.utils import infer_device
|
|
10
|
+
from liger_kernel.utils import is_npu_available
|
|
9
11
|
|
|
10
|
-
if compare_version("triton", operator.ge, "3.0.0"):
|
|
12
|
+
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
|
11
13
|
try:
|
|
12
14
|
# typical import path with dispatch available
|
|
13
15
|
from triton.language.extra.libdevice import rsqrt
|
|
@@ -17,7 +19,10 @@ if compare_version("triton", operator.ge, "3.0.0"):
|
|
|
17
19
|
else:
|
|
18
20
|
from triton.language.math import rsqrt
|
|
19
21
|
|
|
20
|
-
|
|
22
|
+
if infer_device() == "npu":
|
|
23
|
+
MAX_FUSED_SIZE = 16384 # 8192
|
|
24
|
+
else:
|
|
25
|
+
MAX_FUSED_SIZE = 65536
|
|
21
26
|
|
|
22
27
|
|
|
23
28
|
@triton.jit
|
|
@@ -77,15 +82,14 @@ def _group_norm_forward_kernel(
|
|
|
77
82
|
for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
|
|
78
83
|
W = tl.load(W_ptr + channel_idx)
|
|
79
84
|
B = tl.load(B_ptr + channel_idx)
|
|
80
|
-
|
|
85
|
+
# Calculate channel offset within the group
|
|
86
|
+
channel_offset = (channel_idx - group_idx * channels_per_group) * hidden_size_per_channel
|
|
87
|
+
for i in tl.range(0, hidden_size_per_channel, BLOCK_SIZE):
|
|
81
88
|
hidden_size_offsets = i + block_range
|
|
82
89
|
mask = hidden_size_offsets < hidden_size_per_channel
|
|
83
|
-
X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
|
|
90
|
+
X = tl.load(X_ptr + channel_offset + hidden_size_offsets, mask=mask, other=m)
|
|
84
91
|
Y = (X - m) * rstd * W + B
|
|
85
|
-
tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
|
|
86
|
-
|
|
87
|
-
X_ptr += hidden_size_per_channel
|
|
88
|
-
Y_ptr += hidden_size_per_channel
|
|
92
|
+
tl.store(Y_ptr + channel_offset + hidden_size_offsets, Y, mask=mask)
|
|
89
93
|
|
|
90
94
|
tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
|
|
91
95
|
tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
|
liger_kernel/ops/kl_div.py
CHANGED
|
@@ -21,7 +21,12 @@ def get_num_warps(BLOCK_SIZE):
|
|
|
21
21
|
return num_warps
|
|
22
22
|
|
|
23
23
|
|
|
24
|
-
|
|
24
|
+
if infer_device() == "xpu":
|
|
25
|
+
MAX_FUSED_SIZE = 8192
|
|
26
|
+
elif infer_device() == "npu":
|
|
27
|
+
MAX_FUSED_SIZE = 8192
|
|
28
|
+
else:
|
|
29
|
+
MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best
|
|
25
30
|
|
|
26
31
|
REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
|
|
27
32
|
|
|
@@ -116,11 +121,7 @@ def _kldiv_kernel_backward(
|
|
|
116
121
|
|
|
117
122
|
def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
|
|
118
123
|
BT, V = y_pred.shape
|
|
119
|
-
BLOCK_SIZE = (
|
|
120
|
-
min(8192, triton.next_power_of_2(V))
|
|
121
|
-
if infer_device() == "xpu"
|
|
122
|
-
else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
|
123
|
-
)
|
|
124
|
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
|
124
125
|
num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
|
|
125
126
|
|
|
126
127
|
grid = (BT,)
|
|
@@ -159,11 +160,7 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
|
|
|
159
160
|
|
|
160
161
|
def kldiv_backward_triton(target, grad_output, new_grads, log_target):
|
|
161
162
|
BT, V = target.shape
|
|
162
|
-
BLOCK_SIZE = (
|
|
163
|
-
min(8192, triton.next_power_of_2(V))
|
|
164
|
-
if infer_device() == "xpu"
|
|
165
|
-
else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
|
166
|
-
)
|
|
163
|
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
|
167
164
|
num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
|
|
168
165
|
|
|
169
166
|
grid = (BT,)
|
liger_kernel/ops/layer_norm.py
CHANGED
|
@@ -8,8 +8,11 @@ import triton.language as tl
|
|
|
8
8
|
from liger_kernel.ops.utils import calculate_settings
|
|
9
9
|
from liger_kernel.ops.utils import compare_version
|
|
10
10
|
from liger_kernel.ops.utils import ensure_contiguous
|
|
11
|
+
from liger_kernel.ops.utils import get_npu_core_count
|
|
12
|
+
from liger_kernel.ops.utils import set_large_grf_mode
|
|
13
|
+
from liger_kernel.utils import is_npu_available
|
|
11
14
|
|
|
12
|
-
if compare_version("triton", operator.ge, "3.0.0"):
|
|
15
|
+
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
|
13
16
|
try:
|
|
14
17
|
# typical import path with dispatch available
|
|
15
18
|
from triton.language.extra.libdevice import rsqrt
|
|
@@ -123,14 +126,14 @@ def _layer_norm_backward_kernel(
|
|
|
123
126
|
w = tl.load(W_ptr + cols, mask=mask, other=0.0)
|
|
124
127
|
w_f32 = w.to(tl.float32)
|
|
125
128
|
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
129
|
+
for row_idx in range(row_start, row_end):
|
|
130
|
+
# Calculate pointers for this specific row
|
|
131
|
+
row_X_ptr = X_ptr + row_idx * stride_x
|
|
132
|
+
row_DX_ptr = DX_ptr + row_idx * stride_dx
|
|
133
|
+
row_DY_ptr = DY_ptr + row_idx * stride_dy
|
|
134
|
+
row_Mean_ptr = Mean_ptr + row_idx * stride_mean
|
|
135
|
+
row_RSTD_ptr = RSTD_ptr + row_idx * stride_rstd
|
|
132
136
|
|
|
133
|
-
for _ in range(row_start, row_end):
|
|
134
137
|
# Load data for this row
|
|
135
138
|
x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
|
|
136
139
|
dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
|
|
@@ -159,12 +162,6 @@ def _layer_norm_backward_kernel(
|
|
|
159
162
|
dW_row += dw
|
|
160
163
|
db_row += db
|
|
161
164
|
|
|
162
|
-
row_X_ptr += stride_x
|
|
163
|
-
row_DX_ptr += stride_dx
|
|
164
|
-
row_DY_ptr += stride_dy
|
|
165
|
-
row_Mean_ptr += stride_mean
|
|
166
|
-
row_RSTD_ptr += stride_rstd
|
|
167
|
-
|
|
168
165
|
tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
|
|
169
166
|
tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)
|
|
170
167
|
|
|
@@ -203,7 +200,7 @@ def layer_norm_forward(X, W, B, eps):
|
|
|
203
200
|
# XPU-specific optimization
|
|
204
201
|
kernel_args = {}
|
|
205
202
|
if X.device.type == "xpu":
|
|
206
|
-
kernel_args
|
|
203
|
+
set_large_grf_mode(kernel_args)
|
|
207
204
|
|
|
208
205
|
# Launch kernel with one thread block per row for optimal performance
|
|
209
206
|
grid = (n_rows,)
|
|
@@ -253,6 +250,8 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
|
|
253
250
|
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
|
254
251
|
elif X.device.type == "xpu":
|
|
255
252
|
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
|
|
253
|
+
elif X.device.type == "npu":
|
|
254
|
+
sm_count = get_npu_core_count()
|
|
256
255
|
|
|
257
256
|
# fp32 for numerical stability especially.
|
|
258
257
|
_DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
|
|
@@ -271,7 +270,8 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
|
|
271
270
|
kernel_args = {"num_warps": num_warps}
|
|
272
271
|
# XPU-specific optimization
|
|
273
272
|
if X.device.type == "xpu":
|
|
274
|
-
kernel_args.update({"
|
|
273
|
+
kernel_args.update({"num_warps": 32, "num_stages": 4})
|
|
274
|
+
set_large_grf_mode(kernel_args)
|
|
275
275
|
|
|
276
276
|
# Launch kernel with one thread block per row for optimal performance
|
|
277
277
|
_layer_norm_backward_kernel[grid](
|
|
@@ -300,6 +300,7 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
|
|
300
300
|
DX = DX.view(*shape)
|
|
301
301
|
DW = _DW.sum(dim=0).to(W.dtype)
|
|
302
302
|
DB = _DB.sum(dim=0).to(B.dtype)
|
|
303
|
+
|
|
303
304
|
return DX, DW, DB
|
|
304
305
|
|
|
305
306
|
|
liger_kernel/ops/poly_norm.py
CHANGED
|
@@ -7,8 +7,11 @@ import triton.language as tl
|
|
|
7
7
|
from liger_kernel.ops.utils import calculate_settings
|
|
8
8
|
from liger_kernel.ops.utils import compare_version
|
|
9
9
|
from liger_kernel.ops.utils import ensure_contiguous
|
|
10
|
+
from liger_kernel.ops.utils import get_npu_core_count
|
|
11
|
+
from liger_kernel.ops.utils import set_large_grf_mode
|
|
12
|
+
from liger_kernel.utils import is_npu_available
|
|
10
13
|
|
|
11
|
-
if compare_version("triton", operator.ge, "3.0.0"):
|
|
14
|
+
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
|
12
15
|
try:
|
|
13
16
|
from triton.language.extra.libdevice import rsqrt
|
|
14
17
|
except ModuleNotFoundError:
|
|
@@ -138,20 +141,19 @@ def _poly_norm_backward_kernel(
|
|
|
138
141
|
w1 = tl.load(W_ptr + 1).to(tl.float32)
|
|
139
142
|
w2 = tl.load(W_ptr + 2).to(tl.float32)
|
|
140
143
|
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
144
|
+
for row_idx in range(row_start, row_end):
|
|
145
|
+
dy_base = dY_ptr + row_idx * dY_row_stride
|
|
146
|
+
x_base = X_ptr + row_idx * X_row_stride
|
|
147
|
+
dx_base = dX_ptr + row_idx * dX_row_stride
|
|
148
|
+
rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
|
|
145
149
|
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
|
|
149
|
-
X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
|
|
150
|
+
dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0).to(tl.float32)
|
|
151
|
+
X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0).to(tl.float32)
|
|
150
152
|
|
|
151
153
|
# Load cached rstd values
|
|
152
|
-
rstd_3 = tl.load(
|
|
153
|
-
rstd_2 = tl.load(
|
|
154
|
-
rstd_1 = tl.load(
|
|
154
|
+
rstd_3 = tl.load(rstd_base + 0).to(tl.float32)
|
|
155
|
+
rstd_2 = tl.load(rstd_base + 1).to(tl.float32)
|
|
156
|
+
rstd_1 = tl.load(rstd_base + 2).to(tl.float32)
|
|
155
157
|
|
|
156
158
|
# Compute powers
|
|
157
159
|
X_pow3 = X_row * X_row * X_row
|
|
@@ -188,13 +190,7 @@ def _poly_norm_backward_kernel(
|
|
|
188
190
|
dX_row = grad_x_3 + grad_x_2 + grad_x_1
|
|
189
191
|
|
|
190
192
|
# Store gradient
|
|
191
|
-
tl.store(
|
|
192
|
-
|
|
193
|
-
# Update pointers
|
|
194
|
-
dY_ptr += dY_row_stride
|
|
195
|
-
dX_ptr += dX_row_stride
|
|
196
|
-
X_ptr += X_row_stride
|
|
197
|
-
RSTD_ptr += RSTD_row_stride
|
|
193
|
+
tl.store(dx_base + col_offsets, dX_row, mask=mask)
|
|
198
194
|
|
|
199
195
|
# Store accumulated gradients (scalars)
|
|
200
196
|
tl.store(dW_ptr + row_block_id * dW_row_stride + 0, dW0_acc)
|
|
@@ -237,7 +233,7 @@ def poly_norm_forward(X, W, B, eps=1e-6):
|
|
|
237
233
|
# XPU-specific optimization
|
|
238
234
|
kernel_args = {}
|
|
239
235
|
if X.device.type == "xpu":
|
|
240
|
-
kernel_args
|
|
236
|
+
set_large_grf_mode(kernel_args)
|
|
241
237
|
|
|
242
238
|
# Launch kernel
|
|
243
239
|
_poly_norm_forward_kernel[(n_rows,)](
|
|
@@ -290,6 +286,8 @@ def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
|
|
|
290
286
|
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
|
291
287
|
elif X.device.type == "xpu":
|
|
292
288
|
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
|
|
289
|
+
elif X.device.type == "npu":
|
|
290
|
+
sm_count = get_npu_core_count()
|
|
293
291
|
|
|
294
292
|
# Allocate or reuse gradients
|
|
295
293
|
if in_place is True:
|
|
@@ -306,7 +304,7 @@ def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
|
|
|
306
304
|
# XPU-specific optimization
|
|
307
305
|
kernel_args = {}
|
|
308
306
|
if X.device.type == "xpu":
|
|
309
|
-
kernel_args
|
|
307
|
+
set_large_grf_mode(kernel_args)
|
|
310
308
|
|
|
311
309
|
# Launch backward kernel
|
|
312
310
|
_poly_norm_backward_kernel[grid](
|