liger-kernel 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl
- liger_kernel/chunked_loss/cosine_similarity_loss.py +20 -5
- liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
- liger_kernel/chunked_loss/grpo_loss.py +8 -5
- liger_kernel/chunked_loss/jsd_loss.py +39 -11
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +492 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +61 -0
- liger_kernel/ops/backends/_ascend/ops/embedding.py +214 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +191 -0
- liger_kernel/ops/backends/_ascend/ops/llama4_rope.py +298 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +275 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +265 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ops/tvd.py +223 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +367 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +71 -11
- liger_kernel/ops/dyt.py +5 -2
- liger_kernel/ops/fused_add_rms_norm.py +21 -23
- liger_kernel/ops/fused_linear_cross_entropy.py +32 -5
- liger_kernel/ops/geglu.py +5 -3
- liger_kernel/ops/group_norm.py +12 -8
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +89 -69
- liger_kernel/ops/poly_norm.py +19 -21
- liger_kernel/ops/rms_norm.py +149 -71
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +25 -0
- liger_kernel/transformers/__init__.py +25 -0
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +1 -1
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/functional.py +44 -26
- liger_kernel/transformers/fused_add_rms_norm.py +1 -1
- liger_kernel/transformers/fused_linear_cross_entropy.py +9 -4
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +57 -2
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +1 -1
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/falcon_h1.py +19 -5
- liger_kernel/transformers/model/gemma.py +17 -6
- liger_kernel/transformers/model/gemma2.py +17 -8
- liger_kernel/transformers/model/gemma3.py +35 -16
- liger_kernel/transformers/model/glm4.py +16 -4
- liger_kernel/transformers/model/glm4v.py +16 -4
- liger_kernel/transformers/model/glm4v_moe.py +23 -4
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +12 -5
- liger_kernel/transformers/model/llama.py +14 -5
- liger_kernel/transformers/model/llama4.py +16 -4
- liger_kernel/transformers/model/llava.py +12 -4
- liger_kernel/transformers/model/loss_utils.py +37 -3
- liger_kernel/transformers/model/mistral.py +15 -6
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +12 -4
- liger_kernel/transformers/model/olmo2.py +16 -4
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +23 -5
- liger_kernel/transformers/model/phi3.py +14 -7
- liger_kernel/transformers/model/qwen2.py +16 -3
- liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
- liger_kernel/transformers/model/qwen2_vl.py +16 -4
- liger_kernel/transformers/model/qwen3.py +20 -5
- liger_kernel/transformers/model/qwen3_moe.py +19 -5
- liger_kernel/transformers/model/qwen3_next.py +17 -5
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +15 -6
- liger_kernel/transformers/monkey_patch.py +584 -49
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +1 -1
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +8 -3
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +1 -1
- liger_kernel/transformers/sparsemax.py +1 -1
- liger_kernel/transformers/swiglu.py +18 -1
- liger_kernel/transformers/tiled_mlp.py +125 -0
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +54 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/METADATA +14 -4
- liger_kernel-0.6.5.dist-info/RECORD +134 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/WHEEL +1 -1
- liger_kernel-0.6.3.dist-info/RECORD +0 -111
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/top_level.txt +0 -0
liger_kernel/ops/fused_add_rms_norm.py
CHANGED

@@ -8,9 +8,12 @@ import triton.language as tl
 from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import get_npu_core_count
+from liger_kernel.ops.utils import set_large_grf_mode
 from liger_kernel.ops.utils import torch_to_triton_dtype
+from liger_kernel.utils import is_npu_available

-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt

@@ -160,23 +163,21 @@ def _fused_add_rms_norm_backward_kernel(

     dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

-    dY_ptr += row_start * dY_row_stride
-    dX_ptr += row_start * dX_row_stride
-    if has_dS_out:
-        dS_out_ptr += row_start * dS_out_row_stride
-
-    X_ptr += row_start * X_row_stride
-    RSTD_ptr += row_start
-
     W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
     W_row = W_row + offset

-    for _ in range(row_start, row_end):
-        dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
-        X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
+    for row_idx in range(row_start, row_end):
+        dy_base = dY_ptr + row_idx * dY_row_stride
+        dx_base = dX_ptr + row_idx * dX_row_stride
+
+        x_base = X_ptr + row_idx * X_row_stride
+        rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
+
+        dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0)
+        X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0)

         # Get cached rms
-        rstd_row = tl.load(RSTD_ptr)
+        rstd_row = tl.load(rstd_base)

         X_row = X_row.to(tl.float32)

@@ -193,11 +194,11 @@ def _fused_add_rms_norm_backward_kernel(
         dX_row = rstd_row * m

         if has_dS_out:
-            dS_out_row = tl.load(dS_out_ptr + col_offsets, mask=mask, other=0.0)
+            ds_base = dS_out_ptr + row_idx * dS_out_row_stride
+            dS_out_row = tl.load(ds_base + col_offsets, mask=mask, other=0.0)
             dX_row += (rstd_row) * (
                 -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
             ) + dS_out_row
-            dS_out_ptr += dS_out_row_stride
         else:
             dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)

@@ -208,12 +209,7 @@ def _fused_add_rms_norm_backward_kernel(
         # here X_row is already in fp32 (see previous if block)
         dW_row += dY_row * (X_row * rstd_row)

-        tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
-
-        dY_ptr += dY_row_stride
-        dX_ptr += dX_row_stride
-        X_ptr += X_row_stride
-        RSTD_ptr += RSTD_row_stride
+        tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask)

     tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
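The backward kernel now gives each program a contiguous block of rows (row_start to row_end), indexes every row from an explicit per-row base pointer instead of mutating the pointers between iterations, and writes one partial dW row per program that the host reduces afterwards. A minimal PyTorch sketch of that partitioning and reduction (sizes and sm_count are illustrative, not the kernel's actual launch values):

import math
import torch

n_rows, n_cols, sm_count = 1000, 256, 64  # illustrative sizes only
rows_per_program = math.ceil(n_rows / sm_count)

dY = torch.randn(n_rows, n_cols, dtype=torch.float64)
X_hat = torch.randn(n_rows, n_cols, dtype=torch.float64)  # stands in for X_row * rstd_row

# Each "program" accumulates one partial dW row over its assigned rows...
_dW = torch.zeros(sm_count, n_cols, dtype=torch.float64)
for pid in range(sm_count):
    row_start = pid * rows_per_program
    row_end = min((pid + 1) * rows_per_program, n_rows)
    for row_idx in range(row_start, row_end):
        _dW[pid] += dY[row_idx] * X_hat[row_idx]

# ...and the host sums the (sm_count, n_cols) partials into the final gradient.
dW = _dW.sum(dim=0)
assert torch.allclose(dW, (dY * X_hat).sum(dim=0))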
@@ -252,7 +248,7 @@ def fused_add_rms_norm_forward(X, R, W, eps, offset, casting_mode):
     # XPU-specific optimization
     kernel_args = {}
     if X.device.type == "xpu":
-        kernel_args["grf_mode"] = "large"
+        set_large_grf_mode(kernel_args)

     # TODO: add _block_fused_add_rms_norm_forward_kernel
     _fused_add_rms_norm_forward_kernel[(n_rows,)](

@@ -293,6 +289,8 @@ def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BL
         sm_count = torch.cuda.get_device_properties(S.device).multi_processor_count
     elif S.device.type == "xpu":
         sm_count = torch.xpu.get_device_properties(S.device).gpu_eu_count
+    elif S.device.type == "npu":
+        sm_count = get_npu_core_count()

     # fp32 for numerical stability especially.
     _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)

@@ -310,7 +308,7 @@ def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BL
     # XPU-specific optimization
     kernel_args = {}
     if S.device.type == "xpu":
-        kernel_args["grf_mode"] = "large"
+        set_large_grf_mode(kernel_args)

     # TODO: add _block_fused_add_rms_norm_backward_kernel
     _fused_add_rms_norm_backward_kernel[grid](
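The repeated inline XPU launch hint now lives behind a helper imported from liger_kernel/ops/utils.py. Its body is not shown in this diff; a plausible sketch, assuming it only mutates the launch-kwargs dict (the "grf_mode": "large" value is what the pre-0.6.5 inline code set):

def set_large_grf_mode(kernel_args: dict) -> None:
    # Ask the XPU Triton backend for the large general register file (GRF),
    # which reduces register spilling for wide rows. Assumed shape; the real
    # helper lives in liger_kernel/ops/utils.py.
    kernel_args["grf_mode"] = "large"

kernel_args = {}
set_large_grf_mode(kernel_args)  # later expanded as **kernel_args at kernel launch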
liger_kernel/ops/fused_linear_cross_entropy.py
CHANGED

@@ -6,11 +6,12 @@ from liger_kernel.ops.utils import amp_custom_bwd
 from liger_kernel.ops.utils import amp_custom_fwd
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device

 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2
+MAX_FUSED_SIZE = 2048 if infer_device() == "npu" else 65536 // 2


 def fused_linear_cross_entropy_forward(

@@ -27,8 +28,12 @@ def fused_linear_cross_entropy_forward(
     return_z_loss=False,
     accum_dtype=None,
     use_token_scaling=False,
+    return_token_accuracy=False,
 ):
     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+    )
     device = _input.device

     input_requires_grad = _input.requires_grad

@@ -58,9 +63,13 @@ def fused_linear_cross_entropy_forward(
         else:
             grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
             grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
+    else:
+        grad_weight = None
+        grad_bias = None

     loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
     z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+    token_accuracy_1d = torch.zeros(BT, dtype=torch.float32, device=device) if return_token_accuracy else None

     # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
     target_mask = target != ignore_index

@@ -126,6 +135,7 @@ def fused_linear_cross_entropy_forward(
         # unreduced loss
         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
         z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
+        token_accuracy_1d_slice = token_accuracy_1d[start_idx:end_idx] if return_token_accuracy else None

         # ensure _input and target are contiguous
         logits_chunk = logits_chunk.contiguous()

@@ -141,6 +151,10 @@ def fused_linear_cross_entropy_forward(
             loss_ptr=loss_1d_slice,
             z_loss_ptr=z_loss_1d_slice,
             loss_stride=loss_1d_slice.stride(-1),  # always 1
+            token_accuracy_ptr=token_accuracy_1d_slice,
+            token_accuracy_stride=token_accuracy_1d_slice.stride(-1)
+            if return_token_accuracy
+            else 0,  # always 1 if accuracy is enabled
             n_cols=V,
             n_non_ignore=total_n_non_ignore,
             sum_non_ignore_weight=total_sum_non_ignore_ce_weight,

@@ -151,6 +165,7 @@ def fused_linear_cross_entropy_forward(
             reduction=reduction,
             softcap=softcap,
             RETURN_Z_LOSS=return_z_loss,
+            RETURN_TOKEN_ACCURACY=return_token_accuracy,
             HAS_WEIGHT=True if ce_weight is not None else False,
             HAS_SOFTCAPPING=True if softcap is not None else False,
             HAS_GRADIENTS=input_requires_grad,

@@ -167,6 +182,8 @@ def fused_linear_cross_entropy_forward(
         loss_1d[start_idx:end_idx] = loss_1d_slice
         if return_z_loss:
             z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
+        if return_token_accuracy:
+            token_accuracy_1d[start_idx:end_idx] = token_accuracy_1d_slice
         grad_logits_chunk = logits_chunk  # chunk_size x V

         # Apply token scaling to gradients if requested

@@ -198,15 +215,18 @@ def fused_linear_cross_entropy_forward(
         # Return per-token losses
         loss = loss_1d
         z_loss = z_loss_1d if return_z_loss else None
+        token_accuracy = token_accuracy_1d if return_token_accuracy else None
     else:
         loss = torch.sum(loss_1d)
         z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        # For accuracy, we compute the mean across all non-ignored tokens
+        token_accuracy = torch.sum(token_accuracy_1d) / total_n_non_ignore if return_token_accuracy else None

     # Cast back to original dtype
     grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None
     grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None

-    return loss, z_loss, grad_input, grad_weight, grad_bias
+    return loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias


 def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
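For reference, the quantity the fused path writes into token_accuracy_1d corresponds to the following unfused computation (a sketch assuming mean reduction; it materializes the full logits, which is exactly what the fused kernel avoids):

import torch

def token_accuracy_reference(_input, weight, target, ignore_index=-100):
    # Unfused equivalent: argmax over the vocabulary vs. the target,
    # averaged over non-ignored tokens.
    logits = _input @ weight.T        # (BT, V), materialized only in this reference
    preds = logits.argmax(dim=-1)     # (BT,)
    mask = target != ignore_index
    return (preds[mask] == target[mask]).float().mean()

BT, H, V = 8, 16, 32
acc = token_accuracy_reference(torch.randn(BT, H), torch.randn(V, H), torch.randint(0, V, (BT,)))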
@@ -274,6 +294,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         return_z_loss: bool = False,
         accum_dtype=None,
         use_token_scaling: bool = False,
+        return_token_accuracy: bool = False,
     ):
         """
         Fusing the last linear layer with cross-entropy loss

@@ -297,9 +318,10 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
             When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
             Default: False.
+        return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
         """

-        loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
+        loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
             _input=_input,
             weight=weight,
             target=target,

@@ -313,6 +335,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             return_z_loss=return_z_loss,
             accum_dtype=accum_dtype,
             use_token_scaling=use_token_scaling,
+            return_token_accuracy=return_token_accuracy,
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(

@@ -321,13 +344,16 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             grad_bias.detach() if bias is not None else None,
         )
         ctx.return_z_loss = return_z_loss
-        return loss, z_loss
+        ctx.return_token_accuracy = return_token_accuracy
+        return loss, z_loss, token_accuracy

     @staticmethod
     @amp_custom_bwd
-    def backward(ctx, grad_output, grad_output2):
+    def backward(ctx, grad_output, grad_output2, grad_output3):
         if ctx.return_z_loss:
             del grad_output2  # z_loss is only for logging
+        if ctx.return_token_accuracy:
+            del grad_output3  # token_accuracy is only for metrics
         (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
         grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
             grad_output, grad_input, grad_weight, grad_bias

@@ -346,4 +372,5 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,  # use_token_scaling
+            None,  # return_token_accuracy
         )
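The backward signature grows one argument per new forward output: torch.autograd.Function passes one incoming gradient per returned tensor, and metrics-only outputs simply discard theirs. A self-contained illustration of that pattern (not liger code):

import torch

class TwoOutputWithMetric(torch.autograd.Function):
    # Minimal illustration: when forward returns an extra metrics-only tensor,
    # backward receives one grad argument per output and ignores the ones
    # that are not differentiable.
    @staticmethod
    def forward(ctx, x):
        loss = x.sum()
        accuracy = (x > 0).float().mean()  # metrics only, no gradient needed
        ctx.save_for_backward(torch.ones_like(x))
        return loss, accuracy

    @staticmethod
    def backward(ctx, grad_loss, grad_accuracy):
        del grad_accuracy  # metrics output contributes no gradient
        (ones,) = ctx.saved_tensors
        return grad_loss * ones

x = torch.randn(5, requires_grad=True)
loss, acc = TwoOutputWithMetric.apply(x)
loss.backward()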
liger_kernel/ops/geglu.py
CHANGED

@@ -7,8 +7,9 @@ import triton.language as tl
 from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import is_npu_available

-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import tanh

@@ -66,8 +67,9 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
     tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
     tanh_result = tanh(tanh_arg)
     geglu_a = 0.5 * a_row * (1 + tanh_result)
+    geglu_a = geglu_a.to(dc_row.dtype).to(tl.float32)

-    db_row = dc_row * geglu_a
+    db_row = dc_row.cast(tl.float32) * geglu_a

     # Gradient w.r.t. a can be computed with:
     # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))

@@ -78,7 +80,7 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
     da_row = dc_row * b_row * (term1 + term2)

     tl.store(a + col_offsets, da_row, mask=mask)
-    tl.store(b + col_offsets, db_row, mask=mask)
+    tl.store(b + col_offsets, db_row.to(dc_row.dtype), mask=mask)


 def geglu_forward(a, b):
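The new casts compute db in fp32 but first round the GELU gate to the incoming gradient's dtype, so the fused result matches what an unfused bf16/fp16 module would produce. A plain-PyTorch sketch of that gradient path (reference only, not the kernel):

import math
import torch

def geglu_tanh_db_reference(dc, a):
    # db = dc * gelu_tanh(a), with the gate rounded to dc's dtype before the
    # fp32 multiply -- mirroring the casts added in the kernel above.
    a_f32 = a.float()
    sqrt_2_over_pi = math.sqrt(2.0 / math.pi)
    gelu_a = 0.5 * a_f32 * (1.0 + torch.tanh(sqrt_2_over_pi * (a_f32 + 0.044715 * a_f32**3)))
    gelu_a = gelu_a.to(dc.dtype).float()       # round-trip: storage dtype, then back to fp32
    return (dc.float() * gelu_a).to(dc.dtype)  # store in the gradient's dtype

db = geglu_tanh_db_reference(torch.randn(4, 8, dtype=torch.bfloat16),
                             torch.randn(4, 8, dtype=torch.bfloat16))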
liger_kernel/ops/group_norm.py
CHANGED

@@ -6,8 +6,10 @@ import triton.language as tl

 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import infer_device
+from liger_kernel.utils import is_npu_available

-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt

@@ -17,7 +19,10 @@ if compare_version("triton", operator.ge, "3.0.0"):
 else:
     from triton.language.math import rsqrt

-MAX_FUSED_SIZE = 65536
+if infer_device() == "npu":
+    MAX_FUSED_SIZE = 16384  # 8192
+else:
+    MAX_FUSED_SIZE = 65536


 @triton.jit

@@ -77,15 +82,14 @@ def _group_norm_forward_kernel(
     for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
         W = tl.load(W_ptr + channel_idx)
         B = tl.load(B_ptr + channel_idx)
-        for i in tl.range(0, hidden_size_per_channel, BLOCK_SIZE):
+        # Calculate channel offset within the group
+        channel_offset = (channel_idx - group_idx * channels_per_group) * hidden_size_per_channel
+        for i in tl.range(0, hidden_size_per_channel, BLOCK_SIZE):
             hidden_size_offsets = i + block_range
             mask = hidden_size_offsets < hidden_size_per_channel
-            X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
+            X = tl.load(X_ptr + channel_offset + hidden_size_offsets, mask=mask, other=m)
             Y = (X - m) * rstd * W + B
-            tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
-
-        X_ptr += hidden_size_per_channel
-        Y_ptr += hidden_size_per_channel
+            tl.store(Y_ptr + channel_offset + hidden_size_offsets, Y, mask=mask)

     tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
     tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
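The forward kernel now derives each channel's base offset from its index instead of mutating X_ptr/Y_ptr between channels. The index arithmetic, in plain Python (values are illustrative):

# For a channel inside a group, element `i` now lives at
# base + channel_offset + i, rather than behind a pointer that is bumped
# after every channel.
def channel_offset(channel_idx, group_idx, channels_per_group, hidden_size_per_channel):
    return (channel_idx - group_idx * channels_per_group) * hidden_size_per_channel

# e.g. with 4 channels per group and 8 elements per channel, the third
# channel of group 1 (channel_idx=6) starts at offset 16 within the group:
assert channel_offset(6, 1, 4, 8) == 16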
liger_kernel/ops/grpo_loss.py
CHANGED

@@ -128,7 +128,9 @@ def _grpo_loss_fwd_kernel(
     per_token_loss1 = coef_1 * advantage
     per_token_loss2 = coef_2 * advantage
     per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
-    is_clipped = per_token_loss1 < per_token_loss2
+    is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0)
+    is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
+    is_clipped = is_low_clipped | is_high_clipped

     if BETA != 0.0:
         REF_LOGP += off_b * L + off_l
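The kernel now reports a token as clipped only when the PPO-style ratio clamp actually binds, which depends on the advantage's sign. An unfused reference of the indicator (a sketch; coef_1 stands for the probability ratio and the eps arguments mirror EPS_LOW/EPS_HIGH):

import torch

def grpo_clip_fraction_reference(coef_1, advantage, eps_low, eps_high):
    # A token counts as clipped when the ratio is outside
    # [1 - eps_low, 1 + eps_high] *and* the clamp binds for that advantage sign.
    is_low_clipped = (coef_1 < 1 - eps_low) & (advantage < 0)
    is_high_clipped = (coef_1 > 1 + eps_high) & (advantage > 0)
    return (is_low_clipped | is_high_clipped).float().mean()

frac = grpo_clip_fraction_reference(torch.rand(16) * 2, torch.randn(16), 0.2, 0.2)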
liger_kernel/ops/kl_div.py
CHANGED

@@ -21,7 +21,12 @@ def get_num_warps(BLOCK_SIZE):
     return num_warps


-MAX_FUSED_SIZE = 65536 // 4  # 65536 // 4 or 8 works the best
+if infer_device() == "xpu":
+    MAX_FUSED_SIZE = 8192
+elif infer_device() == "npu":
+    MAX_FUSED_SIZE = 8192
+else:
+    MAX_FUSED_SIZE = 65536 // 4  # 65536 // 4 or 8 works the best

 REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]

@@ -116,11 +121,7 @@ def _kldiv_kernel_backward(

 def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps):  # [BT, V]
     BT, V = y_pred.shape
-    BLOCK_SIZE = (
-        min(8192, triton.next_power_of_2(V))
-        if infer_device() == "xpu"
-        else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-    )
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
     num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)

     grid = (BT,)

@@ -159,11 +160,7 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps):  # [BT, V]

 def kldiv_backward_triton(target, grad_output, new_grads, log_target):
     BT, V = target.shape
-    BLOCK_SIZE = (
-        min(8192, triton.next_power_of_2(V))
-        if infer_device() == "xpu"
-        else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-    )
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
     num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)

     grid = (BT,)
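With the device-dependent cap folded into MAX_FUSED_SIZE, the block-size derivation collapses to a single expression. A self-contained sketch of the selection (next_power_of_2 is reimplemented so the snippet runs without Triton; the caps match the hunk above):

def next_power_of_2(n: int) -> int:
    # same result as triton.next_power_of_2 for positive n
    return 1 << (n - 1).bit_length()

def pick_block_size(V: int, device: str) -> int:
    max_fused = 8192 if device in ("xpu", "npu") else 65536 // 4
    return min(max_fused, next_power_of_2(V))

assert pick_block_size(50000, "cuda") == 16384  # capped by 65536 // 4
assert pick_block_size(50000, "npu") == 8192    # capped by the xpu/npu limit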
liger_kernel/ops/layer_norm.py
CHANGED

@@ -1,3 +1,4 @@
+import math
 import operator

 import torch

@@ -7,8 +8,11 @@ import triton.language as tl
 from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import get_npu_core_count
+from liger_kernel.ops.utils import set_large_grf_mode
+from liger_kernel.utils import is_npu_available

-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt

@@ -85,68 +89,81 @@ def _layer_norm_forward_kernel(
 @triton.jit
 def _layer_norm_backward_kernel(
     X_ptr,  # pointer to input, shape (n_rows, n_cols)
+    stride_x,  # stride of each row in input
     W_ptr,  # pointer to weights, shape (n_cols,)
     Mean_ptr,  # pointer to mean, shape (n_rows,)
+    stride_mean,  # stride of each row in mean
     RSTD_ptr,  # pointer to rstd, shape (n_rows,)
+    stride_rstd,  # stride of each row in rstd
     DX_ptr,  # pointer to input grad, shape (n_rows, n_cols)
+    stride_dx,  # stride of each row in input grad
     DW_ptr,  # pointer to weights grad, shape (n_cols,)
+    stride_dw,  # stride of each row in weights grad
     DB_ptr,  # pointer to bias grad, shape (n_cols,)
+    stride_db,  # stride of each row in bias grad
     DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
-    stride_x,  # stride of each row in input
-    stride_dx,  # stride of each row in input grad
     stride_dy,  # stride of each row in output grad
+    n_rows,
     n_cols,
+    rows_per_program: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
-    dtype: tl.constexpr,
-    atomic_dtype: tl.constexpr,
 ):
     """
     References:
     https://arxiv.org/abs/1607.06450
     https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
     """
-    row_idx = tl.program_id(0)
+    row_block_id = tl.program_id(0).to(tl.int64)
+    row_start = row_block_id * rows_per_program
+    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
     cols = tl.arange(0, BLOCK_SIZE)
     mask = cols < n_cols

+    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
     # Pre-load weights once (same optimization as forward pass)
     w = tl.load(W_ptr + cols, mask=mask, other=0.0)
     w_f32 = w.to(tl.float32)

-    [34 removed lines: the old one-program-per-row backward body; not recoverable from this rendering]
+    for row_idx in range(row_start, row_end):
+        # Calculate pointers for this specific row
+        row_X_ptr = X_ptr + row_idx * stride_x
+        row_DX_ptr = DX_ptr + row_idx * stride_dx
+        row_DY_ptr = DY_ptr + row_idx * stride_dy
+        row_Mean_ptr = Mean_ptr + row_idx * stride_mean
+        row_RSTD_ptr = RSTD_ptr + row_idx * stride_rstd
+
+        # Load data for this row
+        x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
+        dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
+        mean = tl.load(row_Mean_ptr)
+        rstd = tl.load(row_RSTD_ptr)
+
+        # Convert to fp32 for numerical stability
+        x_f32 = x.to(tl.float32)
+        dy_f32 = dy.to(tl.float32)
+        mean_f32 = mean.to(tl.float32)
+        rstd_f32 = rstd.to(tl.float32)
+
+        # Compute backward pass for this row
+        x_hat = (x_f32 - mean_f32) * rstd_f32
+        wdy = w_f32 * dy_f32
+        c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
+        c2 = tl.sum(wdy, axis=0) / n_cols
+        dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
+
+        # Store input gradient
+        tl.store(row_DX_ptr + cols, dx, mask=mask)
+
+        # Accumulate weight and bias gradients for this thread block's assigned rows
+        dw = dy_f32 * x_hat
+        db = dy_f32
+        dW_row += dw
+        db_row += db
+
+    tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
+    tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)


 def layer_norm_forward(X, W, B, eps):
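The per-row math in the new kernel is the standard LayerNorm backward. A plain-PyTorch reference of the loop body (the kernel runs this in fp32; dtype handling is elided here):

import torch

def layer_norm_backward_row_reference(x, dy, w, mean, rstd):
    # Mirrors the kernel's per-row body: x_hat, wdy, the two row means c1/c2,
    # dx, and the per-row dW/dB contributions accumulated per program.
    n_cols = x.shape[-1]
    x_hat = (x - mean) * rstd
    wdy = w * dy
    c1 = (x_hat * wdy).sum(-1, keepdim=True) / n_cols
    c2 = wdy.sum(-1, keepdim=True) / n_cols
    dx = (wdy - (x_hat * c1 + c2)) * rstd
    dw = dy * x_hat  # summed over a program's rows into its row of the (sm_count, n_cols) buffer
    db = dy
    return dx, dw, db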
@@ -183,7 +200,7 @@ def layer_norm_forward(X, W, B, eps):
     # XPU-specific optimization
     kernel_args = {}
     if X.device.type == "xpu":
-        kernel_args["grf_mode"] = "large"
+        set_large_grf_mode(kernel_args)

     # Launch kernel with one thread block per row for optimal performance
     grid = (n_rows,)

@@ -228,60 +245,63 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
     dY = dY.view(-1, dim)
     n_rows, n_cols = dY.shape

-    [6 removed lines: the old DX/DW/DB allocation for atomic accumulation; not recoverable from this rendering]
+    sm_count = 1
+    if X.device.type == "cuda":
+        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+    elif X.device.type == "xpu":
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+    elif X.device.type == "npu":
+        sm_count = get_npu_core_count()
+
+    # fp32 for numerical stability especially.
+    _DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+    _DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)

     # Calculate optimal block size and warp configuration
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
     if n_cols > BLOCK_SIZE:
         raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
+    rows_per_program = math.ceil(n_rows / sm_count)
+    grid = (sm_count,)

-    #
-    triton_dtype = (
-        tl.float32
-        if X.dtype == torch.float32
-        else tl.bfloat16
-        if X.dtype == torch.bfloat16
-        else tl.float16
-        if X.dtype == torch.float16
-        else tl.float32  # fallback
-    )
-
-    # Use float32 for atomic operations if bfloat16 is not supported
-    atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype
+    # Allocate gradient tensors
+    DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)

     kernel_args = {"num_warps": num_warps}
     # XPU-specific optimization
     if X.device.type == "xpu":
-        kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
+        kernel_args.update({"num_warps": 32, "num_stages": 4})
+        set_large_grf_mode(kernel_args)

     # Launch kernel with one thread block per row for optimal performance
-    grid = (n_rows,)
     _layer_norm_backward_kernel[grid](
         X,
+        X.stride(0),
         W,
         Mean,
+        Mean.stride(0),
         RSTD,
+        RSTD.stride(0),
         DX,
-        DW,
-        DB,
-        dY,
-        X.stride(0),
         DX.stride(0),
+        _DW,
+        _DW.stride(0),
+        _DB,
+        _DB.stride(0),
+        dY,
         dY.stride(0),
+        n_rows,
         n_cols,
+        rows_per_program=rows_per_program,
         BLOCK_SIZE=BLOCK_SIZE,
-        dtype=triton_dtype,
-        atomic_dtype=atomic_dtype,
         **kernel_args,
     )

     DX = DX.view(*shape)
-    return DX, DW, DB
+    DW = _DW.sum(dim=0).to(W.dtype)
+    DB = _DB.sum(dim=0).to(B.dtype)
+
+    return DX, DW, DB


 class LigerLayerNormFunction(torch.autograd.Function):
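A quick numerical check of the reference above against PyTorch autograd (float64 and hypothetical shapes; eps chosen arbitrarily):

import torch

x = torch.randn(4, 16, dtype=torch.float64, requires_grad=True)
w = torch.randn(16, dtype=torch.float64, requires_grad=True)
b = torch.randn(16, dtype=torch.float64, requires_grad=True)
y = torch.nn.functional.layer_norm(x, (16,), w, b, eps=1e-6)
dy = torch.randn_like(y)
y.backward(dy)

mean = x.detach().mean(-1, keepdim=True)
rstd = (x.detach().var(-1, unbiased=False, keepdim=True) + 1e-6).rsqrt()
dx, dw, db = layer_norm_backward_row_reference(x.detach(), dy, w.detach(), mean, rstd)

assert torch.allclose(dx, x.grad)          # per-row input gradient
assert torch.allclose(dw.sum(0), w.grad)   # programs' partials reduce like _DW.sum(dim=0)
assert torch.allclose(db.sum(0), b.grad)   # likewise for _DB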