liger-kernel-nightly 0.6.2.dev20251011154427__py3-none-any.whl → 0.6.4.dev20251202054858__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of liger-kernel-nightly might be problematic.
- liger_kernel/chunked_loss/cosine_similarity_loss.py +13 -4
- liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
- liger_kernel/chunked_loss/grpo_loss.py +8 -5
- liger_kernel/chunked_loss/jsd_loss.py +18 -5
- liger_kernel/ops/cross_entropy.py +65 -11
- liger_kernel/ops/dyt.py +5 -2
- liger_kernel/ops/fused_add_rms_norm.py +5 -1
- liger_kernel/ops/fused_linear_cross_entropy.py +43 -13
- liger_kernel/ops/geglu.py +2 -1
- liger_kernel/ops/group_norm.py +2 -1
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/layer_norm.py +86 -66
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +7 -2
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +27 -0
- liger_kernel/transformers/cross_entropy.py +8 -3
- liger_kernel/transformers/functional.py +29 -6
- liger_kernel/transformers/fused_linear_cross_entropy.py +8 -3
- liger_kernel/transformers/grpo_loss.py +56 -1
- liger_kernel/transformers/model/falcon_h1.py +19 -5
- liger_kernel/transformers/model/gemma.py +17 -6
- liger_kernel/transformers/model/gemma2.py +14 -5
- liger_kernel/transformers/model/gemma3.py +25 -12
- liger_kernel/transformers/model/glm4.py +16 -4
- liger_kernel/transformers/model/glm4v.py +16 -4
- liger_kernel/transformers/model/glm4v_moe.py +23 -4
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +12 -5
- liger_kernel/transformers/model/llama.py +14 -5
- liger_kernel/transformers/model/llama4.py +16 -4
- liger_kernel/transformers/model/llava.py +12 -4
- liger_kernel/transformers/model/loss_utils.py +31 -3
- liger_kernel/transformers/model/mistral.py +15 -6
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +12 -4
- liger_kernel/transformers/model/olmo2.py +16 -4
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +22 -5
- liger_kernel/transformers/model/phi3.py +14 -7
- liger_kernel/transformers/model/qwen2.py +16 -3
- liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
- liger_kernel/transformers/model/qwen2_vl.py +16 -4
- liger_kernel/transformers/model/qwen3.py +20 -5
- liger_kernel/transformers/model/qwen3_moe.py +19 -5
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +15 -6
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +594 -19
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/rms_norm.py +7 -0
- liger_kernel/transformers/rope.py +43 -0
- liger_kernel/transformers/swiglu.py +17 -0
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/utils.py +25 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/METADATA +4 -1
- liger_kernel_nightly-0.6.4.dev20251202054858.dist-info/RECORD +118 -0
- liger_kernel_nightly-0.6.2.dev20251011154427.dist-info/RECORD +0 -107
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/top_level.txt +0 -0
liger_kernel/ops/fused_linear_cross_entropy.py
CHANGED

@@ -27,10 +27,16 @@ def fused_linear_cross_entropy_forward(
     return_z_loss=False,
     accum_dtype=None,
     use_token_scaling=False,
+    return_token_accuracy=False,
 ):
     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+    )
     device = _input.device
 
+    input_requires_grad = _input.requires_grad
+
     # inputs have shape: BT x H
     # materialized activations will have shape: BT x V
     # the increase in memory = BT x V
@@ -49,15 +55,20 @@ def fused_linear_cross_entropy_forward(
     grad_input = torch.zeros_like(_input, device=device)
 
     # we use fp32 for loss and gradients accumulator
-    if accum_dtype is None:
-        grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
-        grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
+    if input_requires_grad:
+        if accum_dtype is None:
+            grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
+            grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
+        else:
+            grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
+            grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
     else:
-        grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
-        grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
+        grad_weight = None
+        grad_bias = None
 
     loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
     z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+    token_accuracy_1d = torch.zeros(BT, dtype=torch.float32, device=device) if return_token_accuracy else None
 
     # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
     target_mask = target != ignore_index
@@ -123,6 +134,7 @@ def fused_linear_cross_entropy_forward(
         # unreduced loss
         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
         z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
+        token_accuracy_1d_slice = token_accuracy_1d[start_idx:end_idx] if return_token_accuracy else None
 
         # ensure _input and target are contiguous
         logits_chunk = logits_chunk.contiguous()
@@ -138,6 +150,10 @@ def fused_linear_cross_entropy_forward(
             loss_ptr=loss_1d_slice,
             z_loss_ptr=z_loss_1d_slice,
             loss_stride=loss_1d_slice.stride(-1),  # always 1
+            token_accuracy_ptr=token_accuracy_1d_slice,
+            token_accuracy_stride=token_accuracy_1d_slice.stride(-1)
+            if return_token_accuracy
+            else 0,  # always 1 if accuracy is enabled
             n_cols=V,
             n_non_ignore=total_n_non_ignore,
             sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
@@ -148,9 +164,10 @@ def fused_linear_cross_entropy_forward(
             reduction=reduction,
             softcap=softcap,
             RETURN_Z_LOSS=return_z_loss,
+            RETURN_TOKEN_ACCURACY=return_token_accuracy,
             HAS_WEIGHT=True if ce_weight is not None else False,
             HAS_SOFTCAPPING=True if softcap is not None else False,
-            HAS_GRADIENTS=_input.requires_grad,
+            HAS_GRADIENTS=input_requires_grad,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=32 if not is_hip() else 16,
         )
@@ -164,6 +181,8 @@ def fused_linear_cross_entropy_forward(
         loss_1d[start_idx:end_idx] = loss_1d_slice
         if return_z_loss:
             z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
+        if return_token_accuracy:
+            token_accuracy_1d[start_idx:end_idx] = token_accuracy_1d_slice
         grad_logits_chunk = logits_chunk  # chunk_size x V
 
         # Apply token scaling to gradients if requested
@@ -172,12 +191,13 @@ def fused_linear_cross_entropy_forward(
             scaling_factors_expanded = scaling_factors.unsqueeze(-1)  # chunk_size x 1
             grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded
 
-        grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
+        if input_requires_grad:
+            grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
 
-        if grad_weight is not None and _input.requires_grad:
+        if grad_weight is not None and input_requires_grad:
             grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()
 
-        if bias is not None and _input.requires_grad:
+        if bias is not None and input_requires_grad:
             torch.add(
                 input=grad_bias,
                 other=grad_logits_chunk.sum(dim=0),
@@ -194,15 +214,18 @@ def fused_linear_cross_entropy_forward(
         # Return per-token losses
         loss = loss_1d
         z_loss = z_loss_1d if return_z_loss else None
+        token_accuracy = token_accuracy_1d if return_token_accuracy else None
     else:
         loss = torch.sum(loss_1d)
         z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        # For accuracy, we compute the mean across all non-ignored tokens
+        token_accuracy = torch.sum(token_accuracy_1d) / total_n_non_ignore if return_token_accuracy else None
 
     # Cast back to original dtype
     grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None
     grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None
 
-    return loss, z_loss, grad_input, grad_weight, grad_bias
+    return loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias
 
 
 def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
@@ -270,6 +293,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         return_z_loss: bool = False,
         accum_dtype=None,
         use_token_scaling: bool = False,
+        return_token_accuracy: bool = False,
     ):
         """
         Fusing the last linear layer with cross-entropy loss
@@ -293,9 +317,10 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
             When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
             Default: False.
+        return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
         """
 
-        loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
+        loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
             _input=_input,
             weight=weight,
             target=target,
@@ -309,6 +334,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             return_z_loss=return_z_loss,
             accum_dtype=accum_dtype,
             use_token_scaling=use_token_scaling,
+            return_token_accuracy=return_token_accuracy,
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(
@@ -317,13 +343,16 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             grad_bias.detach() if bias is not None else None,
         )
         ctx.return_z_loss = return_z_loss
-        return loss, z_loss
+        ctx.return_token_accuracy = return_token_accuracy
+        return loss, z_loss, token_accuracy
 
     @staticmethod
     @amp_custom_bwd
-    def backward(ctx, grad_output, grad_output2):
+    def backward(ctx, grad_output, grad_output2, grad_output3):
         if ctx.return_z_loss:
             del grad_output2  # z_loss is only for logging
+        if ctx.return_token_accuracy:
+            del grad_output3  # token_accuracy is only for metrics
         (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
         grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
             grad_output, grad_input, grad_weight, grad_bias
@@ -342,4 +371,5 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,  # use_token_scaling
+            None,  # return_token_accuracy
         )
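The `return_token_accuracy` flag threads a per-token accuracy buffer through the same chunked pipeline as the loss, so accuracy can be logged without materializing the BT x V logits. As a semantic reference only (a plain-PyTorch sketch, not liger code, assuming the kernel counts a token as correct when the argmax prediction matches the target), the returned value corresponds to:

import torch

def token_accuracy_reference(logits, target, ignore_index=-100, reduction="mean"):
    # logits: (BT, V), target: (BT,) -- the flattened shapes used by the fused op.
    valid = target != ignore_index
    correct = (logits.argmax(dim=-1) == target) & valid
    if reduction == "none":
        # per-token 0/1 values, analogous to token_accuracy_1d
        return correct.float()
    # mean over non-ignored tokens, analogous to
    # torch.sum(token_accuracy_1d) / total_n_non_ignore
    return correct.sum().float() / valid.sum().clamp(min=1)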
liger_kernel/ops/geglu.py
CHANGED
@@ -7,8 +7,9 @@ import triton.language as tl
 from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import tanh
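The same NPU guard is applied to `group_norm.py` and `layer_norm.py` below: the libdevice probe is only attempted when Triton is recent enough and the code is not running on an Ascend NPU. The `is_npu_available` helper itself comes from `liger_kernel/utils.py` (whose +25 lines are not shown here); a hypothetical sketch of what such a check can look like:

import importlib.util

def is_npu_available() -> bool:
    # Hypothetical stand-in only: treat the presence of the torch_npu extension
    # as "running on an Ascend NPU". The helper actually shipped in
    # liger_kernel.utils may detect the device differently.
    return importlib.util.find_spec("torch_npu") is not None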
liger_kernel/ops/group_norm.py
CHANGED
@@ -6,8 +6,9 @@ import triton.language as tl
 
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
liger_kernel/ops/grpo_loss.py
CHANGED
@@ -128,7 +128,9 @@ def _grpo_loss_fwd_kernel(
     per_token_loss1 = coef_1 * advantage
     per_token_loss2 = coef_2 * advantage
     per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
-    is_clipped = per_token_loss1 < per_token_loss2
+    is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0)
+    is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
+    is_clipped = is_low_clipped | is_high_clipped
 
     if BETA != 0.0:
         REF_LOGP += off_b * L + off_l
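The forward kernel now distinguishes low-side from high-side clipping: a token counts as clipped only when the probability ratio (`coef_1`) leaves the `[1 - eps_low, 1 + eps_high]` band in the direction that actually binds for the sign of its advantage. A small PyTorch reference of the same indicator logic (a sketch, not liger code; `ratio` plays the role of `coef_1`):

import torch

def clip_indicators(ratio, advantage, eps_low=0.2, eps_high=0.2):
    # Low-side clipping only matters for negative advantages,
    # high-side clipping only for positive advantages.
    is_low_clipped = (ratio < 1 - eps_low) & (advantage < 0)
    is_high_clipped = (ratio > 1 + eps_high) & (advantage > 0)
    is_clipped = is_low_clipped | is_high_clipped
    return is_low_clipped, is_high_clipped, is_clipped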
liger_kernel/ops/layer_norm.py
CHANGED
@@ -1,3 +1,4 @@
+import math
 import operator
 
 import torch
@@ -7,8 +8,9 @@ import triton.language as tl
 from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
@@ -85,68 +87,87 @@ def _layer_norm_forward_kernel(
 @triton.jit
 def _layer_norm_backward_kernel(
     X_ptr,  # pointer to input, shape (n_rows, n_cols)
+    stride_x,  # stride of each row in input
     W_ptr,  # pointer to weights, shape (n_cols,)
     Mean_ptr,  # pointer to mean, shape (n_rows,)
+    stride_mean,  # stride of each row in mean
     RSTD_ptr,  # pointer to rstd, shape (n_rows,)
+    stride_rstd,  # stride of each row in rstd
     DX_ptr,  # pointer to input grad, shape (n_rows, n_cols)
+    stride_dx,  # stride of each row in input grad
     DW_ptr,  # pointer to weights grad, shape (n_cols,)
+    stride_dw,  # stride of each row in weights grad
     DB_ptr,  # pointer to bias grad, shape (n_cols,)
+    stride_db,  # stride of each row in bias grad
     DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
-    stride_x,  # stride of each row in input
-    stride_dx,  # stride of each row in input grad
     stride_dy,  # stride of each row in output grad
+    n_rows,
     n_cols,
+    rows_per_program: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
-    dtype: tl.constexpr,
-    atomic_dtype: tl.constexpr,
 ):
     """
     References:
     https://arxiv.org/abs/1607.06450
     https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
     """
-    ...
+    row_block_id = tl.program_id(0).to(tl.int64)
+    row_start = row_block_id * rows_per_program
+    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
     cols = tl.arange(0, BLOCK_SIZE)
     mask = cols < n_cols
 
+    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
     # Pre-load weights once (same optimization as forward pass)
     w = tl.load(W_ptr + cols, mask=mask, other=0.0)
     w_f32 = w.to(tl.float32)
 
     # Calculate pointers for this specific row
-    row_X_ptr = X_ptr + ...
-    row_DX_ptr = DX_ptr + ...
-    row_DY_ptr = DY_ptr + ...
-    row_Mean_ptr = Mean_ptr + ...
-    row_RSTD_ptr = RSTD_ptr + ...
-    ...
+    row_X_ptr = X_ptr + row_start * stride_x
+    row_DX_ptr = DX_ptr + row_start * stride_dx
+    row_DY_ptr = DY_ptr + row_start * stride_dy
+    row_Mean_ptr = Mean_ptr + row_start
+    row_RSTD_ptr = RSTD_ptr + row_start
+
+    for _ in range(row_start, row_end):
+        # Load data for this row
+        x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
+        dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
+        mean = tl.load(row_Mean_ptr)
+        rstd = tl.load(row_RSTD_ptr)
+
+        # Convert to fp32 for numerical stability
+        x_f32 = x.to(tl.float32)
+        dy_f32 = dy.to(tl.float32)
+        mean_f32 = mean.to(tl.float32)
+        rstd_f32 = rstd.to(tl.float32)
+
+        # Compute backward pass for this row
+        x_hat = (x_f32 - mean_f32) * rstd_f32
+        wdy = w_f32 * dy_f32
+        c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
+        c2 = tl.sum(wdy, axis=0) / n_cols
+        dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
+
+        # Store input gradient
+        tl.store(row_DX_ptr + cols, dx, mask=mask)
+
+        # Accumulate weight and bias gradients for this thread block's assigned rows
+        dw = dy_f32 * x_hat
+        db = dy_f32
+        dW_row += dw
+        db_row += db
+
+        row_X_ptr += stride_x
+        row_DX_ptr += stride_dx
+        row_DY_ptr += stride_dy
+        row_Mean_ptr += stride_mean
+        row_RSTD_ptr += stride_rstd
+
+    tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
+    tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)
 
 
 def layer_norm_forward(X, W, B, eps):
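The rewritten backward kernel assigns `rows_per_program` rows to each program and keeps the weight/bias gradient partials in fp32 registers (`dW_row`, `db_row`) instead of issuing atomics, which is why the `dtype`/`atomic_dtype` parameters disappear. The per-row math is the standard LayerNorm backward; a plain-PyTorch reference of what each iteration computes (a sketch for clarity, not liger code):

import torch

def layer_norm_backward_reference(dy, x, w, mean, rstd):
    # dy, x: (n_rows, n_cols); w: (n_cols,); mean, rstd: (n_rows,)
    x_f32, dy_f32, w_f32 = x.float(), dy.float(), w.float()
    x_hat = (x_f32 - mean.float()[:, None]) * rstd.float()[:, None]
    wdy = w_f32 * dy_f32
    c1 = (x_hat * wdy).mean(dim=1, keepdim=True)  # tl.sum(...) / n_cols in the kernel
    c2 = wdy.mean(dim=1, keepdim=True)
    dx = (wdy - (x_hat * c1 + c2)) * rstd.float()[:, None]
    dw = (dy_f32 * x_hat).sum(dim=0)  # what the per-program dW_row partials add up to
    db = dy_f32.sum(dim=0)            # what the per-program db_row partials add up to
    return dx.to(x.dtype), dw, db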
@@ -228,31 +249,25 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
     dY = dY.view(-1, dim)
     n_rows, n_cols = dY.shape
 
-    ...
+    sm_count = 1
+    if X.device.type == "cuda":
+        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+    elif X.device.type == "xpu":
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+
+    # fp32 for numerical stability especially.
+    _DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+    _DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
 
     # Calculate optimal block size and warp configuration
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
     if n_cols > BLOCK_SIZE:
         raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
+    rows_per_program = math.ceil(n_rows / sm_count)
+    grid = (sm_count,)
 
-    # ...
-    triton_dtype = (
-        tl.float32
-        if X.dtype == torch.float32
-        else tl.bfloat16
-        if X.dtype == torch.bfloat16
-        else tl.float16
-        if X.dtype == torch.float16
-        else tl.float32  # fallback
-    )
-
-    # Use float32 for atomic operations if bfloat16 is not supported
-    atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype
+    # Allocate gradient tensors
+    DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
 
     kernel_args = {"num_warps": num_warps}
     # XPU-specific optimization
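With the persistent launch, the host sizes the grid to the number of SMs (or XPU execution-unit groups) and divides the rows evenly among them. A worked example with hypothetical device numbers:

import math

n_rows, sm_count = 8192, 132                     # e.g. 8192 rows on a GPU reporting 132 SMs
rows_per_program = math.ceil(n_rows / sm_count)  # 63
grid = (sm_count,)                               # one persistent program per SM
assert sm_count * rows_per_program >= n_rows     # every row is covered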
@@ -260,28 +275,33 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
         kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
 
     # Launch kernel with one thread block per row for optimal performance
-    grid = (n_rows,)
     _layer_norm_backward_kernel[grid](
         X,
+        X.stride(0),
         W,
         Mean,
+        Mean.stride(0),
        RSTD,
+        RSTD.stride(0),
         DX,
-        DW,
-        DB,
-        dY,
-        X.stride(0),
         DX.stride(0),
+        _DW,
+        _DW.stride(0),
+        _DB,
+        _DB.stride(0),
+        dY,
         dY.stride(0),
+        n_rows,
         n_cols,
+        rows_per_program=rows_per_program,
         BLOCK_SIZE=BLOCK_SIZE,
-        dtype=triton_dtype,
-        atomic_dtype=atomic_dtype,
         **kernel_args,
     )
 
     DX = DX.view(*shape)
-    ...
+    DW = _DW.sum(dim=0).to(W.dtype)
+    DB = _DB.sum(dim=0).to(B.dtype)
+    return DX, DW, DB
 
 
 class LigerLayerNormFunction(torch.autograd.Function):
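Each program writes one fp32 partial row into `_DW`/`_DB`, and the host-side `.sum(dim=0)` replaces the atomic accumulation the old kernel relied on. A minimal sketch (not liger code) of why the two-stage reduction is equivalent to accumulating every row into a single vector:

import torch

torch.manual_seed(0)
n_rows, n_cols, sm_count = 64, 8, 4
contrib = torch.randn(n_rows, n_cols)            # per-row dW (or dB) contributions
rows_per_program = -(-n_rows // sm_count)        # ceil division

# Stage 1: each "program" reduces its assigned slice of rows into its own partial row.
partial = torch.zeros(sm_count, n_cols)
for pid in range(sm_count):
    start = pid * rows_per_program
    end = min(start + rows_per_program, n_rows)
    partial[pid] = contrib[start:end].sum(dim=0)

# Stage 2: the host reduces the partials, mirroring `_DW.sum(dim=0)` above.
assert torch.allclose(partial.sum(dim=0), contrib.sum(dim=0))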