liger-kernel-nightly 0.6.2.dev20250919191028__py3-none-any.whl → 0.6.4.dev20251202054858__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic. Click here for more details.
- liger_kernel/chunked_loss/cosine_similarity_loss.py +13 -4
- liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
- liger_kernel/chunked_loss/grpo_loss.py +8 -5
- liger_kernel/chunked_loss/jsd_loss.py +18 -5
- liger_kernel/ops/cross_entropy.py +120 -63
- liger_kernel/ops/dyt.py +5 -2
- liger_kernel/ops/fused_add_rms_norm.py +5 -1
- liger_kernel/ops/fused_linear_cross_entropy.py +43 -12
- liger_kernel/ops/geglu.py +2 -1
- liger_kernel/ops/group_norm.py +2 -1
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/layer_norm.py +88 -70
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +7 -2
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +33 -0
- liger_kernel/transformers/cross_entropy.py +8 -3
- liger_kernel/transformers/functional.py +29 -6
- liger_kernel/transformers/fused_linear_cross_entropy.py +8 -3
- liger_kernel/transformers/grpo_loss.py +56 -1
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +19 -7
- liger_kernel/transformers/model/gemma2.py +22 -7
- liger_kernel/transformers/model/gemma3.py +52 -14
- liger_kernel/transformers/model/glm4.py +18 -5
- liger_kernel/transformers/model/glm4v.py +18 -5
- liger_kernel/transformers/model/glm4v_moe.py +25 -5
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +16 -6
- liger_kernel/transformers/model/llama4.py +18 -5
- liger_kernel/transformers/model/llava.py +18 -6
- liger_kernel/transformers/model/loss_utils.py +31 -3
- liger_kernel/transformers/model/mistral.py +17 -7
- liger_kernel/transformers/model/mixtral.py +24 -9
- liger_kernel/transformers/model/mllama.py +14 -5
- liger_kernel/transformers/model/olmo2.py +18 -5
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +41 -5
- liger_kernel/transformers/model/phi3.py +16 -8
- liger_kernel/transformers/model/qwen2.py +18 -4
- liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
- liger_kernel/transformers/model/qwen2_vl.py +24 -7
- liger_kernel/transformers/model/qwen3.py +22 -6
- liger_kernel/transformers/model/qwen3_moe.py +27 -7
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +17 -7
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +729 -4
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/rms_norm.py +7 -0
- liger_kernel/transformers/rope.py +43 -0
- liger_kernel/transformers/swiglu.py +17 -0
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/utils.py +25 -0
- {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/METADATA +13 -6
- liger_kernel_nightly-0.6.4.dev20251202054858.dist-info/RECORD +118 -0
- liger_kernel_nightly-0.6.2.dev20250919191028.dist-info/RECORD +0 -105
- {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/top_level.txt +0 -0
liger_kernel/ops/layer_norm.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import math
|
|
1
2
|
import operator
|
|
2
3
|
|
|
3
4
|
import torch
|
|
@@ -7,8 +8,9 @@ import triton.language as tl
|
|
|
7
8
|
from liger_kernel.ops.utils import calculate_settings
|
|
8
9
|
from liger_kernel.ops.utils import compare_version
|
|
9
10
|
from liger_kernel.ops.utils import ensure_contiguous
|
|
11
|
+
from liger_kernel.utils import is_npu_available
|
|
10
12
|
|
|
11
|
-
if compare_version("triton", operator.ge, "3.0.0"):
|
|
13
|
+
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
|
12
14
|
try:
|
|
13
15
|
# typical import path with dispatch available
|
|
14
16
|
from triton.language.extra.libdevice import rsqrt
|
|
@@ -63,12 +65,11 @@ def _layer_norm_forward_kernel(
|
|
|
63
65
|
X_f32 = X_row.to(tl.float32)
|
|
64
66
|
|
|
65
67
|
# Compute statistics in fp32 for numerical stability
|
|
66
|
-
|
|
67
|
-
mean = tl.sum(X_f32, axis=0) / n_cols_f32
|
|
68
|
+
mean = tl.sum(X_f32, axis=0) / n_cols
|
|
68
69
|
X_centered = X_f32 - mean
|
|
69
70
|
# Apply mask to variance calculation to exclude contributions from masked elements
|
|
70
71
|
X_centered_masked = tl.where(mask, X_centered, 0.0)
|
|
71
|
-
var = tl.sum(X_centered_masked * X_centered_masked, axis=0) /
|
|
72
|
+
var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols
|
|
72
73
|
rstd = rsqrt(var + eps)
|
|
73
74
|
|
|
74
75
|
# Store statistics (convert back to original dtype only once)
|
|
@@ -86,69 +87,87 @@ def _layer_norm_forward_kernel(
|
|
|
86
87
|
@triton.jit
|
|
87
88
|
def _layer_norm_backward_kernel(
|
|
88
89
|
X_ptr, # pointer to input, shape (n_rows, n_cols)
|
|
90
|
+
stride_x, # stride of each row in input
|
|
89
91
|
W_ptr, # pointer to weights, shape (n_cols,)
|
|
90
92
|
Mean_ptr, # pointer to mean, shape (n_rows,)
|
|
93
|
+
stride_mean, # stride of each row in mean
|
|
91
94
|
RSTD_ptr, # pointer to rstd, shape (n_rows,)
|
|
95
|
+
stride_rstd, # stride of each row in rstd
|
|
92
96
|
DX_ptr, # pointer to input grad, shape (n_rows, n_cols)
|
|
97
|
+
stride_dx, # stride of each row in input grad
|
|
93
98
|
DW_ptr, # pointer to weights grad, shape (n_cols,)
|
|
99
|
+
stride_dw, # stride of each row in weights grad
|
|
94
100
|
DB_ptr, # pointer to bias grad, shape (n_cols,)
|
|
101
|
+
stride_db, # stride of each row in bias grad
|
|
95
102
|
DY_ptr, # pointer to output grad, shape (n_rows, n_cols)
|
|
96
|
-
stride_x, # stride of each row in input
|
|
97
|
-
stride_dx, # stride of each row in input grad
|
|
98
103
|
stride_dy, # stride of each row in output grad
|
|
104
|
+
n_rows,
|
|
99
105
|
n_cols,
|
|
106
|
+
rows_per_program: tl.constexpr,
|
|
100
107
|
BLOCK_SIZE: tl.constexpr,
|
|
101
|
-
dtype: tl.constexpr,
|
|
102
|
-
atomic_dtype: tl.constexpr,
|
|
103
108
|
):
|
|
104
109
|
"""
|
|
105
110
|
References:
|
|
106
111
|
https://arxiv.org/abs/1607.06450
|
|
107
112
|
https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
|
|
108
113
|
"""
|
|
109
|
-
|
|
114
|
+
row_block_id = tl.program_id(0).to(tl.int64)
|
|
115
|
+
row_start = row_block_id * rows_per_program
|
|
116
|
+
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
|
110
117
|
cols = tl.arange(0, BLOCK_SIZE)
|
|
111
118
|
mask = cols < n_cols
|
|
112
119
|
|
|
120
|
+
dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
|
121
|
+
db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
|
122
|
+
|
|
113
123
|
# Pre-load weights once (same optimization as forward pass)
|
|
114
124
|
w = tl.load(W_ptr + cols, mask=mask, other=0.0)
|
|
115
125
|
w_f32 = w.to(tl.float32)
|
|
116
|
-
n_cols_f32 = n_cols.to(tl.float32)
|
|
117
126
|
|
|
118
127
|
# Calculate pointers for this specific row
|
|
119
|
-
row_X_ptr = X_ptr +
|
|
120
|
-
row_DX_ptr = DX_ptr +
|
|
121
|
-
row_DY_ptr = DY_ptr +
|
|
122
|
-
row_Mean_ptr = Mean_ptr +
|
|
123
|
-
row_RSTD_ptr = RSTD_ptr +
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
128
|
+
row_X_ptr = X_ptr + row_start * stride_x
|
|
129
|
+
row_DX_ptr = DX_ptr + row_start * stride_dx
|
|
130
|
+
row_DY_ptr = DY_ptr + row_start * stride_dy
|
|
131
|
+
row_Mean_ptr = Mean_ptr + row_start
|
|
132
|
+
row_RSTD_ptr = RSTD_ptr + row_start
|
|
133
|
+
|
|
134
|
+
for _ in range(row_start, row_end):
|
|
135
|
+
# Load data for this row
|
|
136
|
+
x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
|
|
137
|
+
dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
|
|
138
|
+
mean = tl.load(row_Mean_ptr)
|
|
139
|
+
rstd = tl.load(row_RSTD_ptr)
|
|
140
|
+
|
|
141
|
+
# Convert to fp32 for numerical stability
|
|
142
|
+
x_f32 = x.to(tl.float32)
|
|
143
|
+
dy_f32 = dy.to(tl.float32)
|
|
144
|
+
mean_f32 = mean.to(tl.float32)
|
|
145
|
+
rstd_f32 = rstd.to(tl.float32)
|
|
146
|
+
|
|
147
|
+
# Compute backward pass for this row
|
|
148
|
+
x_hat = (x_f32 - mean_f32) * rstd_f32
|
|
149
|
+
wdy = w_f32 * dy_f32
|
|
150
|
+
c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
|
|
151
|
+
c2 = tl.sum(wdy, axis=0) / n_cols
|
|
152
|
+
dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
|
|
153
|
+
|
|
154
|
+
# Store input gradient
|
|
155
|
+
tl.store(row_DX_ptr + cols, dx, mask=mask)
|
|
156
|
+
|
|
157
|
+
# Accumulate weight and bias gradients for this thread block's assigned rows
|
|
158
|
+
dw = dy_f32 * x_hat
|
|
159
|
+
db = dy_f32
|
|
160
|
+
dW_row += dw
|
|
161
|
+
db_row += db
|
|
162
|
+
|
|
163
|
+
row_X_ptr += stride_x
|
|
164
|
+
row_DX_ptr += stride_dx
|
|
165
|
+
row_DY_ptr += stride_dy
|
|
166
|
+
row_Mean_ptr += stride_mean
|
|
167
|
+
row_RSTD_ptr += stride_rstd
|
|
168
|
+
|
|
169
|
+
tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
|
|
170
|
+
tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)
|
|
152
171
|
|
|
153
172
|
|
|
154
173
|
def layer_norm_forward(X, W, B, eps):
|
|
@@ -230,31 +249,25 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
|
|
230
249
|
dY = dY.view(-1, dim)
|
|
231
250
|
n_rows, n_cols = dY.shape
|
|
232
251
|
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
252
|
+
sm_count = 1
|
|
253
|
+
if X.device.type == "cuda":
|
|
254
|
+
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
|
255
|
+
elif X.device.type == "xpu":
|
|
256
|
+
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
|
|
257
|
+
|
|
258
|
+
# fp32 for numerical stability especially.
|
|
259
|
+
_DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
|
|
260
|
+
_DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
|
|
239
261
|
|
|
240
262
|
# Calculate optimal block size and warp configuration
|
|
241
263
|
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
|
242
264
|
if n_cols > BLOCK_SIZE:
|
|
243
265
|
raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
|
|
266
|
+
rows_per_program = math.ceil(n_rows / sm_count)
|
|
267
|
+
grid = (sm_count,)
|
|
244
268
|
|
|
245
|
-
#
|
|
246
|
-
|
|
247
|
-
tl.float32
|
|
248
|
-
if X.dtype == torch.float32
|
|
249
|
-
else tl.bfloat16
|
|
250
|
-
if X.dtype == torch.bfloat16
|
|
251
|
-
else tl.float16
|
|
252
|
-
if X.dtype == torch.float16
|
|
253
|
-
else tl.float32 # fallback
|
|
254
|
-
)
|
|
255
|
-
|
|
256
|
-
# Use float32 for atomic operations if bfloat16 is not supported
|
|
257
|
-
atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype
|
|
269
|
+
# Allocate gradient tensors
|
|
270
|
+
DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
|
258
271
|
|
|
259
272
|
kernel_args = {"num_warps": num_warps}
|
|
260
273
|
# XPU-specific optimization
|
|
@@ -262,28 +275,33 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
|
|
262
275
|
kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
|
|
263
276
|
|
|
264
277
|
# Launch kernel with one thread block per row for optimal performance
|
|
265
|
-
grid = (n_rows,)
|
|
266
278
|
_layer_norm_backward_kernel[grid](
|
|
267
279
|
X,
|
|
280
|
+
X.stride(0),
|
|
268
281
|
W,
|
|
269
282
|
Mean,
|
|
283
|
+
Mean.stride(0),
|
|
270
284
|
RSTD,
|
|
285
|
+
RSTD.stride(0),
|
|
271
286
|
DX,
|
|
272
|
-
DW,
|
|
273
|
-
DB,
|
|
274
|
-
dY,
|
|
275
|
-
X.stride(0),
|
|
276
287
|
DX.stride(0),
|
|
288
|
+
_DW,
|
|
289
|
+
_DW.stride(0),
|
|
290
|
+
_DB,
|
|
291
|
+
_DB.stride(0),
|
|
292
|
+
dY,
|
|
277
293
|
dY.stride(0),
|
|
294
|
+
n_rows,
|
|
278
295
|
n_cols,
|
|
296
|
+
rows_per_program=rows_per_program,
|
|
279
297
|
BLOCK_SIZE=BLOCK_SIZE,
|
|
280
|
-
dtype=triton_dtype,
|
|
281
|
-
atomic_dtype=atomic_dtype,
|
|
282
298
|
**kernel_args,
|
|
283
299
|
)
|
|
284
300
|
|
|
285
301
|
DX = DX.view(*shape)
|
|
286
|
-
|
|
302
|
+
DW = _DW.sum(dim=0).to(W.dtype)
|
|
303
|
+
DB = _DB.sum(dim=0).to(B.dtype)
|
|
304
|
+
return DX, DW, DB
|
|
287
305
|
|
|
288
306
|
|
|
289
307
|
class LigerLayerNormFunction(torch.autograd.Function):
|
|
@@ -0,0 +1,390 @@
|
|
|
1
|
+
import operator
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import triton
|
|
5
|
+
import triton.language as tl
|
|
6
|
+
|
|
7
|
+
from liger_kernel.ops.utils import calculate_settings
|
|
8
|
+
from liger_kernel.ops.utils import compare_version
|
|
9
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
|
10
|
+
from liger_kernel.utils import get_npu_multi_processor_count
|
|
11
|
+
from liger_kernel.utils import is_npu_available
|
|
12
|
+
|
|
13
|
+
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
|
14
|
+
try:
|
|
15
|
+
from triton.language.extra.libdevice import rsqrt
|
|
16
|
+
except ModuleNotFoundError:
|
|
17
|
+
from triton.language.extra.cuda.libdevice import rsqrt
|
|
18
|
+
else:
|
|
19
|
+
from triton.language.math import rsqrt
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@triton.jit
def _poly_norm_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,  # weight: [3] for [w0, w1, w2]
    B_ptr,  # bias: scalar
    RSTD_ptr,  # cache rstd for backward: shape (n_rows, 3)
    RSTD_row_stride,
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    PolyNorm forward kernel; each program instance handles one row.

    PolyNorm formula:
        y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
        where norm(u) = u / sqrt(mean(u²) + ε)

    All arithmetic is carried out in fp32. Only the final output is cast back
    to the input dtype. The three per-row rstd values (for x³, x², x, in that
    order) are written to RSTD for the backward pass.

    Reference:
    1. https://github.com/BryceZhuo/PolyCom/
    2. https://arxiv.org/pdf/2411.03884
    """
    row_idx = tl.program_id(0).to(tl.int64)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # Move the base pointers to this program's row
    Y_ptr += row_idx * Y_row_stride
    X_ptr += row_idx * X_row_stride
    RSTD_ptr += row_idx * RSTD_row_stride

    # Out-of-range lanes load 0, so they add nothing to the row sums below
    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
    X_row_dtype = X_row.dtype
    # Upcast before forming powers. mean((x³)²) = mean(x⁶) overflows fp16 once
    # |x| > ~6.3 (and loses most precision in bf16), which would zero rstd_3 or
    # produce NaN outputs. The backward kernel already works in fp32.
    X_f32 = X_row.to(tl.float32)

    # Load weights and bias in fp32
    w0 = tl.load(W_ptr + 0).to(tl.float32)
    w1 = tl.load(W_ptr + 1).to(tl.float32)
    w2 = tl.load(W_ptr + 2).to(tl.float32)
    b = tl.load(B_ptr).to(tl.float32)

    # Compute x³, x², x
    X_pow3 = X_f32 * X_f32 * X_f32
    X_pow2 = X_f32 * X_f32
    X_pow1 = X_f32

    # Compute norm(x³): norm(u) = u * rsqrt(mean(u²) + eps)
    mean_square_3 = tl.sum(X_pow3 * X_pow3, axis=0) / n_cols
    rstd_3 = rsqrt(mean_square_3 + eps)
    norm_x3 = X_pow3 * rstd_3

    # Compute norm(x²)
    mean_square_2 = tl.sum(X_pow2 * X_pow2, axis=0) / n_cols
    rstd_2 = rsqrt(mean_square_2 + eps)
    norm_x2 = X_pow2 * rstd_2

    # Compute norm(x)
    mean_square_1 = tl.sum(X_pow1 * X_pow1, axis=0) / n_cols
    rstd_1 = rsqrt(mean_square_1 + eps)
    norm_x1 = X_pow1 * rstd_1

    # Cache rstd values for backward (RSTD layout per row: [rstd_3, rstd_2, rstd_1])
    tl.store(RSTD_ptr + 0, rstd_3)
    tl.store(RSTD_ptr + 1, rstd_2)
    tl.store(RSTD_ptr + 2, rstd_1)

    # Compute output: y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
    Y_row = w0 * norm_x3 + w1 * norm_x2 + w2 * norm_x1 + b

    # Store output, cast back to the input dtype only once at the end
    tl.store(Y_ptr + col_offsets, Y_row.to(X_row_dtype), mask=mask)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@triton.jit
def _poly_norm_backward_kernel(
    dY_ptr,
    dY_row_stride,
    dX_ptr,
    dX_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,
    RSTD_ptr,
    RSTD_row_stride,
    dW_ptr,  # shape: (n_programs, 3)
    dW_row_stride,
    dB_ptr,  # shape: (n_programs,)
    n_rows,
    n_cols,
    rows_per_program: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    PolyNorm backward kernel.

    Program `pid` handles rows [pid * rows_per_program, min((pid + 1) * rows_per_program, n_rows)).
    It writes dX for those rows, then stores one partial row of dW (3 values) and
    one partial dB value. The caller sums the partials across programs.

    Closed-form input gradient:
        ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
    with
        - D_p = RMS(x^p) = 1/rstd_p
        - S_p = sum(grad * x^p) over the row
        - d = n_cols
        - p ∈ {3, 2, 1}
    Parameter gradients per row: dW_p += rstd_p * S_p, dB += sum(grad).
    """
    pid = tl.program_id(0).to(tl.int64)
    first_row = pid * rows_per_program
    last_row = min((pid + 1) * rows_per_program, n_rows)
    cols = tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < n_cols

    # The three weights are loop-invariant: load them once, in fp32
    w0 = tl.load(W_ptr + 0).to(tl.float32)
    w1 = tl.load(W_ptr + 1).to(tl.float32)
    w2 = tl.load(W_ptr + 2).to(tl.float32)

    # Scalar partial sums for this program's slice of rows
    acc_dw0 = 0.0
    acc_dw1 = 0.0
    acc_dw2 = 0.0
    acc_db = 0.0

    # Row cursors, advanced by one row stride per iteration
    dy_cursor = dY_ptr + first_row * dY_row_stride
    dx_cursor = dX_ptr + first_row * dX_row_stride
    x_cursor = X_ptr + first_row * X_row_stride
    rstd_cursor = RSTD_ptr + first_row * RSTD_row_stride

    for _ in range(first_row, last_row):
        # Upstream gradient and input in fp32; masked lanes are 0 and drop out of every sum
        g = tl.load(dy_cursor + cols, mask=in_bounds, other=0.0).to(tl.float32)
        x = tl.load(x_cursor + cols, mask=in_bounds, other=0.0).to(tl.float32)

        # rstd triple cached by the forward pass: [rstd(x³), rstd(x²), rstd(x)]
        r3 = tl.load(rstd_cursor + 0).to(tl.float32)
        r2 = tl.load(rstd_cursor + 1).to(tl.float32)
        r1 = tl.load(rstd_cursor + 2).to(tl.float32)

        x_sq = x * x
        x_cube = x * x * x

        # Row reductions S_p = sum(grad * x^p)
        s3 = tl.sum(g * x_cube, axis=0)
        s2 = tl.sum(g * x_sq, axis=0)
        s1 = tl.sum(g * x, axis=0)

        # Each term: linear part p*x^(p-1)*rstd_p*grad minus the rstd-derivative correction
        lin3 = 3.0 * x_sq * r3 * g
        corr3 = (3.0 / n_cols) * x * x * x * x * x * (r3 * r3 * r3) * s3
        lin2 = 2.0 * x * r2 * g
        corr2 = (2.0 / n_cols) * x * x * x * (r2 * r2 * r2) * s2
        lin1 = 1.0 * r1 * g
        corr1 = (1.0 / n_cols) * x * (r1 * r1 * r1) * s1

        dx = w0 * (lin3 - corr3) + w1 * (lin2 - corr2) + w2 * (lin1 - corr1)
        tl.store(dx_cursor + cols, dx, mask=in_bounds)

        acc_dw0 += r3 * s3
        acc_dw1 += r2 * s2
        acc_dw2 += r1 * s1
        acc_db += tl.sum(g, axis=0)

        dy_cursor += dY_row_stride
        dx_cursor += dX_row_stride
        x_cursor += X_row_stride
        rstd_cursor += RSTD_row_stride

    # Publish this program's partial parameter gradients
    partial_dw = dW_ptr + pid * dW_row_stride
    tl.store(partial_dw + 0, acc_dw0)
    tl.store(partial_dw + 1, acc_dw1)
    tl.store(partial_dw + 2, acc_dw2)
    tl.store(dB_ptr + pid, acc_db)
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def poly_norm_forward(X, W, B, eps=1e-6):
    """
    PolyNorm Forward Pass

    Args:
        X: input tensor of shape (*, H) where H is hidden dimension
        W: weight tensor holding exactly 3 elements [w0, w1, w2]
        B: single-element bias tensor
        eps: epsilon for numerical stability

    Returns:
        Y: output tensor of same shape as X
        X: input viewed as 2D (n_rows, H) (for backward)
        RSTD: fp32 tensor of shape (n_rows, 3) with cached rstd values (for backward)
        BLOCK_SIZE: block size used
        num_warps: number of warps used

    Raises:
        ValueError: if W does not have exactly 3 elements or B is not a single element.
    """
    # Validate before doing any work, using real exceptions. The kernel
    # unconditionally reads W[0..2] and B[0], and `assert` statements are removed
    # under `python -O`. That would turn a bad shape into an out-of-bounds read.
    if W.numel() != 3:
        raise ValueError("Weight tensor must have shape (3,)")
    if B.numel() != 1:
        raise ValueError("Bias must be a scalar")

    shape = X.shape
    dim = shape[-1]
    X = X.view(-1, dim)
    n_rows, n_cols = X.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
    # RSTD caches one fp32 rstd triple (for x³, x², x) per row for the backward pass
    RSTD = torch.empty((n_rows, 3), dtype=torch.float32, device=X.device)

    # XPU-specific optimization
    kernel_args = {}
    if X.device.type == "xpu":
        kernel_args["grf_mode"] = "large"

    # Launch kernel: one program per row
    _poly_norm_forward_kernel[(n_rows,)](
        Y,
        Y.stride(0),
        X,
        X.stride(0),
        W,
        B,
        RSTD,
        RSTD.stride(0),
        n_cols,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        **kernel_args,
    )

    return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
    """
    PolyNorm Backward Pass

    Args:
        dY: gradient of output
        X: input tensor (already reshaped to 2D)
        W: weight tensor
        RSTD: cached rstd values from forward
        BLOCK_SIZE: block size from forward
        num_warps: number of warps from forward
        in_place: whether to in-place modify dY to store dX (saves memory)

    Returns:
        dX: gradient w.r.t. input
        dW: gradient w.r.t. weight
        dB: gradient w.r.t. bias
    """
    shape = dY.shape
    dim = shape[-1]
    dY = dY.view(-1, dim)
    n_rows, n_cols = dY.shape

    # One program per SM / EU / NPU core. Each program handles a contiguous slice
    # of rows and writes one partial dW/dB row, which is reduced below.
    sm_count = 1
    if X.device.type == "cuda":
        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
    elif X.device.type == "xpu":
        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
    elif X.device.type == "npu":
        sm_count = get_npu_multi_processor_count()

    # Reuse dY's storage for dX when requested. Otherwise an uninitialized buffer
    # is enough: the kernel writes every (row, col < n_cols) element, so
    # zero-filling it first would be wasted work.
    if in_place is True:
        dX = dY
    else:
        dX = torch.empty_like(dY)

    _dW = torch.empty((sm_count, 3), dtype=torch.float32, device=W.device)
    _dB = torch.empty((sm_count,), dtype=torch.float32, device=W.device)

    # Exact integer ceil-division: rows_per_program * sm_count >= n_rows
    rows_per_program = (n_rows + sm_count - 1) // sm_count
    grid = (sm_count,)

    # XPU-specific optimization
    kernel_args = {}
    if X.device.type == "xpu":
        kernel_args["grf_mode"] = "large"

    # Launch backward kernel
    _poly_norm_backward_kernel[grid](
        dY,
        dY.stride(0),
        dX,
        dX.stride(0),
        X,
        X.stride(0),
        W,
        RSTD,
        RSTD.stride(0),
        _dW,
        _dW.stride(0),
        _dB,
        n_rows,
        n_cols,
        rows_per_program,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        **kernel_args,
    )

    # Reduce the per-program partial gradients
    dX = dX.view(*shape)
    dW = _dW.sum(dim=0).to(W.dtype)
    dB = _dB.sum().to(W.dtype)

    return dX, dW, dB
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
class LigerPolyNormFunction(torch.autograd.Function):
    """
    Autograd wrapper around the Triton PolyNorm kernels.

    PolyNorm formula:
        y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
        where norm(u) = u / sqrt(mean(u²) + ε)

    The forward pass saves the 2D view of X, the weight, and the per-row rstd
    values. The backward pass evaluates the closed-form gradient
        ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, X, W, B, eps=1e-6, in_place=True):
        """
        Args:
            X: input tensor of shape (B, T, H) or (BxT, H)
            W: weight tensor of shape (3,) for [w0, w1, w2]
            B: bias scalar
            eps: epsilon added to each mean square before the rsqrt
            in_place: if True, backward writes dX into grad_output's storage (saves memory)

        Returns:
            Y: output tensor of same shape as X
        """
        out, x_2d, rstd_cache, block_size, warps = poly_norm_forward(X, W, B, eps)
        # Launch configuration and in-place flag are reused by backward
        ctx.BLOCK_SIZE = block_size
        ctx.num_warps = warps
        ctx.in_place = in_place
        ctx.save_for_backward(x_2d, W, rstd_cache)
        return out

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output):
        """
        Args:
            grad_output: gradient of the output

        Returns:
            (dX, dW, dB, None, None): eps and in_place receive no gradient.
        """
        x_2d, weight, rstd_cache = ctx.saved_tensors
        grads = poly_norm_backward(
            grad_output,
            x_2d,
            weight,
            rstd_cache,
            ctx.BLOCK_SIZE,
            ctx.num_warps,
            ctx.in_place,
        )
        return (*grads, None, None)
|
liger_kernel/ops/rms_norm.py
CHANGED
|
@@ -21,8 +21,10 @@ from liger_kernel.ops.utils import calculate_settings
|
|
|
21
21
|
from liger_kernel.ops.utils import compare_version
|
|
22
22
|
from liger_kernel.ops.utils import ensure_contiguous
|
|
23
23
|
from liger_kernel.ops.utils import torch_to_triton_dtype
|
|
24
|
+
from liger_kernel.utils import get_npu_multi_processor_count
|
|
25
|
+
from liger_kernel.utils import is_npu_available
|
|
24
26
|
|
|
25
|
-
if compare_version("triton", operator.ge, "3.0.0"):
|
|
27
|
+
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
|
26
28
|
try:
|
|
27
29
|
# typical import path with dispatch available
|
|
28
30
|
from triton.language.extra.libdevice import rsqrt
|
|
@@ -349,7 +351,8 @@ def _block_rms_norm_backward_kernel(
|
|
|
349
351
|
|
|
350
352
|
# calculate the gradient of W
|
|
351
353
|
if casting_mode == _CASTING_MODE_LLAMA:
|
|
352
|
-
|
|
354
|
+
# TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
|
|
355
|
+
dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
|
|
353
356
|
else:
|
|
354
357
|
# here X_row is already in fp32 (see previous if block)
|
|
355
358
|
dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
|
|
@@ -449,6 +452,8 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
|
|
|
449
452
|
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
|
450
453
|
elif X.device.type == "xpu":
|
|
451
454
|
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
|
|
455
|
+
elif X.device.type == "npu":
|
|
456
|
+
sm_count = get_npu_multi_processor_count()
|
|
452
457
|
|
|
453
458
|
# fp32 for numerical stability especially.
|
|
454
459
|
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
|