liger-kernel-nightly 0.6.2.dev20251011154427__py3-none-any.whl → 0.6.4.dev20260107111351__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry, and is provided for informational purposes only.
- liger_kernel/chunked_loss/cosine_similarity_loss.py +20 -5
- liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
- liger_kernel/chunked_loss/grpo_loss.py +8 -5
- liger_kernel/chunked_loss/jsd_loss.py +39 -11
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +43 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +244 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +75 -12
- liger_kernel/ops/dyt.py +5 -2
- liger_kernel/ops/fused_add_rms_norm.py +5 -1
- liger_kernel/ops/fused_linear_cross_entropy.py +45 -14
- liger_kernel/ops/geglu.py +5 -3
- liger_kernel/ops/group_norm.py +2 -1
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/layer_norm.py +86 -66
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +131 -49
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +14 -0
- liger_kernel/transformers/__init__.py +30 -0
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +1 -1
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/functional.py +48 -25
- liger_kernel/transformers/fused_add_rms_norm.py +1 -1
- liger_kernel/transformers/fused_linear_cross_entropy.py +9 -4
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +57 -2
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +1 -1
- liger_kernel/transformers/model/falcon_h1.py +19 -5
- liger_kernel/transformers/model/gemma.py +17 -6
- liger_kernel/transformers/model/gemma2.py +14 -5
- liger_kernel/transformers/model/gemma3.py +26 -12
- liger_kernel/transformers/model/glm4.py +16 -4
- liger_kernel/transformers/model/glm4v.py +16 -4
- liger_kernel/transformers/model/glm4v_moe.py +23 -4
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +12 -5
- liger_kernel/transformers/model/llama.py +14 -5
- liger_kernel/transformers/model/llama4.py +16 -4
- liger_kernel/transformers/model/llava.py +12 -4
- liger_kernel/transformers/model/loss_utils.py +31 -3
- liger_kernel/transformers/model/mistral.py +15 -6
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +12 -4
- liger_kernel/transformers/model/olmo2.py +16 -4
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +23 -5
- liger_kernel/transformers/model/phi3.py +14 -7
- liger_kernel/transformers/model/qwen2.py +16 -3
- liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
- liger_kernel/transformers/model/qwen2_vl.py +16 -4
- liger_kernel/transformers/model/qwen3.py +20 -5
- liger_kernel/transformers/model/qwen3_moe.py +19 -5
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +15 -6
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +702 -48
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +15 -3
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +1 -1
- liger_kernel/transformers/sparsemax.py +1 -1
- liger_kernel/transformers/swiglu.py +18 -1
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +52 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/METADATA +12 -3
- liger_kernel_nightly-0.6.4.dev20260107111351.dist-info/RECORD +130 -0
- liger_kernel_nightly-0.6.2.dev20251011154427.dist-info/RECORD +0 -107
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/top_level.txt +0 -0
liger_kernel/ops/fused_add_rms_norm.py
CHANGED

```diff
@@ -9,8 +9,10 @@ from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import torch_to_triton_dtype
+from liger_kernel.utils import get_npu_multi_processor_count
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
@@ -293,6 +295,8 @@ def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BL
         sm_count = torch.cuda.get_device_properties(S.device).multi_processor_count
     elif S.device.type == "xpu":
         sm_count = torch.xpu.get_device_properties(S.device).gpu_eu_count
+    elif S.device.type == "npu":
+        sm_count = get_npu_multi_processor_count()
 
     # fp32 for numerical stability especially.
     _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
```
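The NPU branches above rely on helpers that this release adds to liger_kernel/utils.py (+52 lines); their bodies are not shown in this diff. A minimal sketch of what they plausibly look like — the torch_npu calls and the property fallback are assumptions, not the library's confirmed implementation:

```python
# Hypothetical sketch of the new liger_kernel.utils helpers; only the
# names are taken from the imports in the diff above.
import torch


def is_npu_available() -> bool:
    # Ascend NPU support comes from the torch_npu plugin; if it cannot be
    # imported, there is no NPU backend to dispatch to.
    try:
        import torch_npu  # noqa: F401

        return torch.npu.is_available()
    except ImportError:
        return False


def get_npu_multi_processor_count() -> int:
    # Assumed analogue of torch.cuda.get_device_properties(...).multi_processor_count,
    # used above to size the (sm_count, n_cols) gradient scratch buffer.
    props = torch.npu.get_device_properties(torch.npu.current_device())
    return getattr(props, "multi_processor_count", 1)
```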
liger_kernel/ops/fused_linear_cross_entropy.py
CHANGED

```diff
@@ -6,11 +6,12 @@ from liger_kernel.ops.utils import amp_custom_bwd
 from liger_kernel.ops.utils import amp_custom_fwd
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2
+MAX_FUSED_SIZE = 2048 if infer_device() == "npu" else 65536 // 2
 
 
 def fused_linear_cross_entropy_forward(
```
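MAX_FUSED_SIZE is now picked at import time from the detected accelerator, capping the Triton block size at 2048 on Ascend NPUs. infer_device also lives in liger_kernel/utils.py and is not shown in this diff; a plausible sketch (the probe order and device strings are assumptions):

```python
# Hypothetical sketch of liger_kernel.utils.infer_device as used above.
import torch


def infer_device() -> str:
    """Return the accelerator type string ("cuda", "xpu", "npu") or "cpu"."""
    if torch.cuda.is_available():  # also true on ROCm builds
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if hasattr(torch, "npu") and torch.npu.is_available():  # needs torch_npu
        return "npu"
    return "cpu"
```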
```diff
@@ -27,10 +28,16 @@ def fused_linear_cross_entropy_forward(
     return_z_loss=False,
     accum_dtype=None,
     use_token_scaling=False,
+    return_token_accuracy=False,
 ):
     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+    )
     device = _input.device
 
+    input_requires_grad = _input.requires_grad
+
     # inputs have shape: BT x H
     # materialized activations will have shape: BT x V
     # the increase in memory = BT x V
@@ -49,15 +56,20 @@ def fused_linear_cross_entropy_forward(
     grad_input = torch.zeros_like(_input, device=device)
 
     # we use fp32 for loss and gradients accumulator
-    if accum_dtype is None:
-        grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
-        grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
+    if input_requires_grad:
+        if accum_dtype is None:
+            grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
+            grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
+        else:
+            grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
+            grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
     else:
-        grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
-        grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
+        grad_weight = None
+        grad_bias = None
 
     loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
     z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+    token_accuracy_1d = torch.zeros(BT, dtype=torch.float32, device=device) if return_token_accuracy else None
 
     # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
     target_mask = target != ignore_index
@@ -123,6 +135,7 @@ def fused_linear_cross_entropy_forward(
         # unreduced loss
         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
         z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
+        token_accuracy_1d_slice = token_accuracy_1d[start_idx:end_idx] if return_token_accuracy else None
 
         # ensure _input and target are contiguous
         logits_chunk = logits_chunk.contiguous()
@@ -138,6 +151,10 @@ def fused_linear_cross_entropy_forward(
             loss_ptr=loss_1d_slice,
             z_loss_ptr=z_loss_1d_slice,
             loss_stride=loss_1d_slice.stride(-1),  # always 1
+            token_accuracy_ptr=token_accuracy_1d_slice,
+            token_accuracy_stride=token_accuracy_1d_slice.stride(-1)
+            if return_token_accuracy
+            else 0,  # always 1 if accuracy is enabled
             n_cols=V,
             n_non_ignore=total_n_non_ignore,
             sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
@@ -148,9 +165,10 @@ def fused_linear_cross_entropy_forward(
             reduction=reduction,
             softcap=softcap,
             RETURN_Z_LOSS=return_z_loss,
+            RETURN_TOKEN_ACCURACY=return_token_accuracy,
             HAS_WEIGHT=True if ce_weight is not None else False,
             HAS_SOFTCAPPING=True if softcap is not None else False,
-            HAS_GRADIENTS=_input.requires_grad,
+            HAS_GRADIENTS=input_requires_grad,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=32 if not is_hip() else 16,
         )
@@ -164,6 +182,8 @@ def fused_linear_cross_entropy_forward(
         loss_1d[start_idx:end_idx] = loss_1d_slice
         if return_z_loss:
             z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
+        if return_token_accuracy:
+            token_accuracy_1d[start_idx:end_idx] = token_accuracy_1d_slice
        grad_logits_chunk = logits_chunk  # chunk_size x V
 
         # Apply token scaling to gradients if requested
@@ -172,12 +192,13 @@ def fused_linear_cross_entropy_forward(
             scaling_factors_expanded = scaling_factors.unsqueeze(-1)  # chunk_size x 1
             grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded
 
-        grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
+        if input_requires_grad:
+            grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
 
-        if grad_weight is not None and _input.requires_grad:
+        if grad_weight is not None and input_requires_grad:
             grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()
 
-        if bias is not None and _input.requires_grad:
+        if bias is not None and input_requires_grad:
             torch.add(
                 input=grad_bias,
                 other=grad_logits_chunk.sum(dim=0),
@@ -194,15 +215,18 @@ def fused_linear_cross_entropy_forward(
         # Return per-token losses
         loss = loss_1d
         z_loss = z_loss_1d if return_z_loss else None
+        token_accuracy = token_accuracy_1d if return_token_accuracy else None
     else:
         loss = torch.sum(loss_1d)
         z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        # For accuracy, we compute the mean across all non-ignored tokens
+        token_accuracy = torch.sum(token_accuracy_1d) / total_n_non_ignore if return_token_accuracy else None
 
     # Cast back to original dtype
     grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None
     grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None
 
-    return loss, z_loss, grad_input, grad_weight, grad_bias
+    return loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias
 
 
 def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
```
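The semantics of the new return_token_accuracy flag are easiest to read in eager PyTorch. The sketch below is a reference for what the fused path computes without ever materializing the full BT x V logits tensor; the helper name is illustrative, not library API, and it ignores softcapping and class weights:

```python
import torch


def reference_token_accuracy(_input, weight, target, bias=None, ignore_index=-100):
    # Materializes the logits (which the fused kernel avoids) purely to show
    # the definition: fraction of non-ignored tokens whose argmax matches.
    logits = _input @ weight.t()  # BT x V
    if bias is not None:
        logits = logits + bias
    pred = logits.argmax(dim=-1)  # BT
    mask = target != ignore_index
    # Matches the reduced path above: sum of per-token hits / total_n_non_ignore.
    return (pred[mask] == target[mask]).float().mean()
```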
```diff
@@ -270,6 +294,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         return_z_loss: bool = False,
         accum_dtype=None,
         use_token_scaling: bool = False,
+        return_token_accuracy: bool = False,
     ):
         """
         Fusing the last linear layer with cross-entropy loss
@@ -293,9 +318,10 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
                 When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
                 Default: False.
+            return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
         """
 
-        loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
+        loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
             _input=_input,
             weight=weight,
             target=target,
@@ -309,6 +335,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             return_z_loss=return_z_loss,
             accum_dtype=accum_dtype,
             use_token_scaling=use_token_scaling,
+            return_token_accuracy=return_token_accuracy,
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(
@@ -317,13 +344,16 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             grad_bias.detach() if bias is not None else None,
         )
         ctx.return_z_loss = return_z_loss
-        return loss, z_loss
+        ctx.return_token_accuracy = return_token_accuracy
+        return loss, z_loss, token_accuracy
 
     @staticmethod
     @amp_custom_bwd
-    def backward(ctx, grad_output, grad_output2):
+    def backward(ctx, grad_output, grad_output2, grad_output3):
         if ctx.return_z_loss:
             del grad_output2  # z_loss is only for logging
+        if ctx.return_token_accuracy:
+            del grad_output3  # token_accuracy is only for metrics
         (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
         grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
             grad_output, grad_input, grad_weight, grad_bias
@@ -342,4 +372,5 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,  # use_token_scaling
+            None,  # return_token_accuracy
         )
```
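Because forward now returns three tensors, backward must accept one incoming gradient per output, which is why its signature gains grad_output3. A minimal, generic illustration of the pattern (a toy function, unrelated to Liger's kernels):

```python
import torch


class LossWithMetric(torch.autograd.Function):
    """Toy autograd.Function returning (loss, metric); the metric is a
    statistic for logging, so its incoming gradient is simply discarded."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        loss = (x**2).sum()
        metric = (x > 0).float().mean()  # logging-only output
        return loss, metric

    @staticmethod
    def backward(ctx, grad_loss, grad_metric):  # one grad per forward output
        del grad_metric  # analogous to grad_output3 above
        (x,) = ctx.saved_tensors
        return 2 * x * grad_loss


x = torch.randn(4, requires_grad=True)
loss, metric = LossWithMetric.apply(x)
loss.backward()  # grad_metric arrives as None and is ignored
```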
liger_kernel/ops/geglu.py
CHANGED

```diff
@@ -7,8 +7,9 @@ import triton.language as tl
 from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import tanh
@@ -66,8 +67,9 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
     tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
     tanh_result = tanh(tanh_arg)
     geglu_a = 0.5 * a_row * (1 + tanh_result)
+    geglu_a = geglu_a.to(dc_row.dtype).to(tl.float32)
 
-    db_row = dc_row * geglu_a
+    db_row = dc_row.cast(tl.float32) * geglu_a
 
     # Gradient w.r.t. a can be computed with:
     # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
@@ -78,7 +80,7 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
     da_row = dc_row * b_row * (term1 + term2)
 
     tl.store(a + col_offsets, da_row, mask=mask)
-    tl.store(b + col_offsets, db_row, mask=mask)
+    tl.store(b + col_offsets, db_row.to(dc_row.dtype), mask=mask)
 
 
 def geglu_forward(a, b):
```
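The extra geglu_a round-trip quantizes the fp32 activation to the I/O dtype before the fp32 multiply, presumably so the stored db matches a reference that computes the activation in the I/O dtype. The effect is visible in plain PyTorch (a standalone illustration, not library code):

```python
import torch

dc = torch.randn(8, dtype=torch.bfloat16)      # upstream grad in the I/O dtype
geglu_a = torch.randn(8, dtype=torch.float32)  # activation computed in fp32

direct = dc.float() * geglu_a
round_trip = dc.float() * geglu_a.to(torch.bfloat16).float()

# The two products can differ in the low bits; after the final downcast that
# can change the stored gradient, which is what the kernel now controls for.
print((direct - round_trip).abs().max())
```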
liger_kernel/ops/group_norm.py
CHANGED

```diff
@@ -6,8 +6,9 @@ import triton.language as tl
 
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
```
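group_norm.py gets the same guard as the other kernels. The hunks only show the first lines of each try block; the full gated-import pattern in these files plausibly looks like the following sketch (the fallback branches are assumptions inferred from the visible comments, not confirmed by this diff):

```python
import operator

from liger_kernel.ops.utils import compare_version
from liger_kernel.utils import is_npu_available

if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
    try:
        # typical import path with dispatch available
        from triton.language.extra.libdevice import rsqrt
    except ModuleNotFoundError:
        # assumed fallback for builds that package libdevice elsewhere
        from triton.language.math import rsqrt
else:
    # NPU builds (and pre-3.0 triton) take the generic implementation
    from triton.language.math import rsqrt
```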
liger_kernel/ops/grpo_loss.py
CHANGED

```diff
@@ -128,7 +128,9 @@ def _grpo_loss_fwd_kernel(
     per_token_loss1 = coef_1 * advantage
     per_token_loss2 = coef_2 * advantage
     per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
-    is_clipped = per_token_loss1 < per_token_loss2
+    is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0)
+    is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
+    is_clipped = is_low_clipped | is_high_clipped
 
     if BETA != 0.0:
         REF_LOGP += off_b * L + off_l
```
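The rewritten indicator tracks the two PPO-style clip boundaries directly instead of deriving clipping from the loss comparison: a token only counts as clipped when the ratio falls outside [1 - eps_low, 1 + eps_high] on the side where the clamp actually binds the objective. An eager mirror of the kernel logic (the function name and defaults are illustrative):

```python
import torch


def grpo_clip_fraction(coef_1, advantage, eps_low=0.2, eps_high=0.2):
    # coef_1 is the token-level probability ratio; EPS_LOW/EPS_HIGH in the
    # kernel correspond to eps_low/eps_high here.
    is_low_clipped = (coef_1 < 1 - eps_low) & (advantage < 0)
    is_high_clipped = (coef_1 > 1 + eps_high) & (advantage > 0)
    is_clipped = is_low_clipped | is_high_clipped
    return is_clipped.float().mean()
```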
liger_kernel/ops/layer_norm.py
CHANGED

```diff
@@ -1,3 +1,4 @@
+import math
 import operator
 
 import torch
@@ -7,8 +8,9 @@ import triton.language as tl
 from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
```
```diff
@@ -85,68 +87,87 @@ def _layer_norm_forward_kernel(
 @triton.jit
 def _layer_norm_backward_kernel(
     X_ptr,  # pointer to input, shape (n_rows, n_cols)
+    stride_x,  # stride of each row in input
     W_ptr,  # pointer to weights, shape (n_cols,)
     Mean_ptr,  # pointer to mean, shape (n_rows,)
+    stride_mean,  # stride of each row in mean
     RSTD_ptr,  # pointer to rstd, shape (n_rows,)
+    stride_rstd,  # stride of each row in rstd
     DX_ptr,  # pointer to input grad, shape (n_rows, n_cols)
+    stride_dx,  # stride of each row in input grad
     DW_ptr,  # pointer to weights grad, shape (n_cols,)
+    stride_dw,  # stride of each row in weights grad
     DB_ptr,  # pointer to bias grad, shape (n_cols,)
+    stride_db,  # stride of each row in bias grad
     DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
-    stride_x,  # stride of each row in input
-    stride_dx,  # stride of each row in input grad
     stride_dy,  # stride of each row in output grad
+    n_rows,
     n_cols,
+    rows_per_program: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
-    dtype: tl.constexpr,
-    atomic_dtype: tl.constexpr,
 ):
     """
     References:
     https://arxiv.org/abs/1607.06450
     https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
     """
-    row_idx = tl.program_id(0).to(tl.int64)
+    row_block_id = tl.program_id(0).to(tl.int64)
+    row_start = row_block_id * rows_per_program
+    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
     cols = tl.arange(0, BLOCK_SIZE)
     mask = cols < n_cols
 
+    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
     # Pre-load weights once (same optimization as forward pass)
     w = tl.load(W_ptr + cols, mask=mask, other=0.0)
     w_f32 = w.to(tl.float32)
 
     # Calculate pointers for this specific row
-    row_X_ptr = X_ptr + row_idx * stride_x
-    row_DX_ptr = DX_ptr + row_idx * stride_dx
-    row_DY_ptr = DY_ptr + row_idx * stride_dy
-    row_Mean_ptr = Mean_ptr + row_idx
-    row_RSTD_ptr = RSTD_ptr + row_idx
-
-    # Load data for this row
-    x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
-    dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
-    mean = tl.load(row_Mean_ptr)
-    rstd = tl.load(row_RSTD_ptr)
-
-    # Convert to fp32 for numerical stability
-    x_f32 = x.to(tl.float32)
-    dy_f32 = dy.to(tl.float32)
-    mean_f32 = mean.to(tl.float32)
-    rstd_f32 = rstd.to(tl.float32)
-
-    # Compute backward pass for this row
-    x_hat = (x_f32 - mean_f32) * rstd_f32
-    wdy = w_f32 * dy_f32
-    c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
-    c2 = tl.sum(wdy, axis=0) / n_cols
-    dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
-
-    # Store input gradient
-    tl.store(row_DX_ptr + cols, dx.to(dtype), mask=mask)
-
-    # Accumulate weight and bias gradients
-    dw_row = dy_f32 * x_hat
-    db_row = dy_f32
-    tl.atomic_add(DW_ptr + cols, dw_row.to(atomic_dtype), mask=mask)
-    tl.atomic_add(DB_ptr + cols, db_row.to(atomic_dtype), mask=mask)
+    row_X_ptr = X_ptr + row_start * stride_x
+    row_DX_ptr = DX_ptr + row_start * stride_dx
+    row_DY_ptr = DY_ptr + row_start * stride_dy
+    row_Mean_ptr = Mean_ptr + row_start
+    row_RSTD_ptr = RSTD_ptr + row_start
+
+    for _ in range(row_start, row_end):
+        # Load data for this row
+        x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
+        dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
+        mean = tl.load(row_Mean_ptr)
+        rstd = tl.load(row_RSTD_ptr)
+
+        # Convert to fp32 for numerical stability
+        x_f32 = x.to(tl.float32)
+        dy_f32 = dy.to(tl.float32)
+        mean_f32 = mean.to(tl.float32)
+        rstd_f32 = rstd.to(tl.float32)
+
+        # Compute backward pass for this row
+        x_hat = (x_f32 - mean_f32) * rstd_f32
+        wdy = w_f32 * dy_f32
+        c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
+        c2 = tl.sum(wdy, axis=0) / n_cols
+        dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
+
+        # Store input gradient
+        tl.store(row_DX_ptr + cols, dx, mask=mask)
+
+        # Accumulate weight and bias gradients for this thread block's assigned rows
+        dw = dy_f32 * x_hat
+        db = dy_f32
+        dW_row += dw
+        db_row += db
+
+        row_X_ptr += stride_x
+        row_DX_ptr += stride_dx
+        row_DY_ptr += stride_dy
+        row_Mean_ptr += stride_mean
+        row_RSTD_ptr += stride_rstd
+
+    tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
+    tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)
 
 
 def layer_norm_forward(X, W, B, eps):
```
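The backward kernel is now persistent: the grid has one program per SM rather than one per row, and each program walks rows_per_program rows while accumulating its own dW/dB partial. The scheduling arithmetic in plain Python, with made-up example numbers:

```python
import math

n_rows, sm_count = 10_000, 132  # illustrative values only
rows_per_program = math.ceil(n_rows / sm_count)  # -> 76

# Program i covers rows [i * rows_per_program, min((i + 1) * rows_per_program, n_rows)).
for i in (0, sm_count - 1):
    start = i * rows_per_program
    end = min((i + 1) * rows_per_program, n_rows)
    print(f"program {i}: rows {start}..{end - 1}")  # the last program takes the remainder
```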
```diff
@@ -228,31 +249,25 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
     dY = dY.view(-1, dim)
     n_rows, n_cols = dY.shape
 
-    # Allocate gradient tensors
-    DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
-    # Atomic accumulation needs a supported dtype; bfloat16 falls back to float32
-    grad_dtype = torch.float32 if X.dtype == torch.bfloat16 else X.dtype
-    DW = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
-    DB = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
+    sm_count = 1
+    if X.device.type == "cuda":
+        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+    elif X.device.type == "xpu":
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+
+    # fp32 for numerical stability especially.
+    _DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+    _DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
 
     # Calculate optimal block size and warp configuration
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
     if n_cols > BLOCK_SIZE:
         raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
+    rows_per_program = math.ceil(n_rows / sm_count)
+    grid = (sm_count,)
 
-    # Map the torch dtype to the matching triton compute dtype
-    triton_dtype = (
-        tl.float32
-        if X.dtype == torch.float32
-        else tl.bfloat16
-        if X.dtype == torch.bfloat16
-        else tl.float16
-        if X.dtype == torch.float16
-        else tl.float32  # fallback
-    )
-
-    # Use float32 for atomic operations if bfloat16 is not supported
-    atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype
+    # Allocate gradient tensors
+    DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
 
     kernel_args = {"num_warps": num_warps}
     # XPU-specific optimization
```
```diff
@@ -260,28 +275,33 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
         kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
 
     # Launch kernel with one thread block per row for optimal performance
-    grid = (n_rows,)
     _layer_norm_backward_kernel[grid](
         X,
+        X.stride(0),
         W,
         Mean,
+        Mean.stride(0),
         RSTD,
+        RSTD.stride(0),
         DX,
-        DW,
-        DB,
-        dY,
-        X.stride(0),
         DX.stride(0),
+        _DW,
+        _DW.stride(0),
+        _DB,
+        _DB.stride(0),
+        dY,
         dY.stride(0),
+        n_rows,
         n_cols,
+        rows_per_program=rows_per_program,
         BLOCK_SIZE=BLOCK_SIZE,
-        dtype=triton_dtype,
-        atomic_dtype=atomic_dtype,
         **kernel_args,
     )
 
     DX = DX.view(*shape)
-    return DX, DW.to(W.dtype), DB.to(B.dtype)
+    DW = _DW.sum(dim=0).to(W.dtype)
+    DB = _DB.sum(dim=0).to(B.dtype)
+    return DX, DW, DB
 
 
 class LigerLayerNormFunction(torch.autograd.Function):
```