liger-kernel-nightly 0.6.3.dev20251121202601__py3-none-any.whl → 0.6.3.dev20251121213521__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/layer_norm.py +84 -65
- {liger_kernel_nightly-0.6.3.dev20251121202601.dist-info → liger_kernel_nightly-0.6.3.dev20251121213521.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.6.3.dev20251121202601.dist-info → liger_kernel_nightly-0.6.3.dev20251121213521.dist-info}/RECORD +7 -7
- {liger_kernel_nightly-0.6.3.dev20251121202601.dist-info → liger_kernel_nightly-0.6.3.dev20251121213521.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.3.dev20251121202601.dist-info → liger_kernel_nightly-0.6.3.dev20251121213521.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.3.dev20251121202601.dist-info → liger_kernel_nightly-0.6.3.dev20251121213521.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.3.dev20251121202601.dist-info → liger_kernel_nightly-0.6.3.dev20251121213521.dist-info}/top_level.txt +0 -0
liger_kernel/ops/layer_norm.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import math
|
|
1
2
|
import operator
|
|
2
3
|
|
|
3
4
|
import torch
|
|
@@ -85,68 +86,87 @@ def _layer_norm_forward_kernel(
|
|
|
85
86
|
@triton.jit
|
|
86
87
|
def _layer_norm_backward_kernel(
|
|
87
88
|
X_ptr, # pointer to input, shape (n_rows, n_cols)
|
|
89
|
+
stride_x, # stride of each row in input
|
|
88
90
|
W_ptr, # pointer to weights, shape (n_cols,)
|
|
89
91
|
Mean_ptr, # pointer to mean, shape (n_rows,)
|
|
92
|
+
stride_mean, # stride of each row in mean
|
|
90
93
|
RSTD_ptr, # pointer to rstd, shape (n_rows,)
|
|
94
|
+
stride_rstd, # stride of each row in rstd
|
|
91
95
|
DX_ptr, # pointer to input grad, shape (n_rows, n_cols)
|
|
96
|
+
stride_dx, # stride of each row in input grad
|
|
92
97
|
DW_ptr, # pointer to weights grad, shape (n_cols,)
|
|
98
|
+
stride_dw, # stride of each row in weights grad
|
|
93
99
|
DB_ptr, # pointer to bias grad, shape (n_cols,)
|
|
100
|
+
stride_db, # stride of each row in bias grad
|
|
94
101
|
DY_ptr, # pointer to output grad, shape (n_rows, n_cols)
|
|
95
|
-
stride_x, # stride of each row in input
|
|
96
|
-
stride_dx, # stride of each row in input grad
|
|
97
102
|
stride_dy, # stride of each row in output grad
|
|
103
|
+
n_rows,
|
|
98
104
|
n_cols,
|
|
105
|
+
rows_per_program: tl.constexpr,
|
|
99
106
|
BLOCK_SIZE: tl.constexpr,
|
|
100
|
-
dtype: tl.constexpr,
|
|
101
|
-
atomic_dtype: tl.constexpr,
|
|
102
107
|
):
|
|
103
108
|
"""
|
|
104
109
|
References:
|
|
105
110
|
https://arxiv.org/abs/1607.06450
|
|
106
111
|
https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
|
|
107
112
|
"""
|
|
108
|
-
|
|
113
|
+
row_block_id = tl.program_id(0).to(tl.int64)
|
|
114
|
+
row_start = row_block_id * rows_per_program
|
|
115
|
+
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
|
109
116
|
cols = tl.arange(0, BLOCK_SIZE)
|
|
110
117
|
mask = cols < n_cols
|
|
111
118
|
|
|
119
|
+
dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
|
120
|
+
db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
|
121
|
+
|
|
112
122
|
# Pre-load weights once (same optimization as forward pass)
|
|
113
123
|
w = tl.load(W_ptr + cols, mask=mask, other=0.0)
|
|
114
124
|
w_f32 = w.to(tl.float32)
|
|
115
125
|
|
|
116
126
|
# Calculate pointers for this specific row
|
|
117
|
-
row_X_ptr = X_ptr +
|
|
118
|
-
row_DX_ptr = DX_ptr +
|
|
119
|
-
row_DY_ptr = DY_ptr +
|
|
120
|
-
row_Mean_ptr = Mean_ptr +
|
|
121
|
-
row_RSTD_ptr = RSTD_ptr +
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
127
|
+
row_X_ptr = X_ptr + row_start * stride_x
|
|
128
|
+
row_DX_ptr = DX_ptr + row_start * stride_dx
|
|
129
|
+
row_DY_ptr = DY_ptr + row_start * stride_dy
|
|
130
|
+
row_Mean_ptr = Mean_ptr + row_start
|
|
131
|
+
row_RSTD_ptr = RSTD_ptr + row_start
|
|
132
|
+
|
|
133
|
+
for _ in range(row_start, row_end):
|
|
134
|
+
# Load data for this row
|
|
135
|
+
x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
|
|
136
|
+
dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
|
|
137
|
+
mean = tl.load(row_Mean_ptr)
|
|
138
|
+
rstd = tl.load(row_RSTD_ptr)
|
|
139
|
+
|
|
140
|
+
# Convert to fp32 for numerical stability
|
|
141
|
+
x_f32 = x.to(tl.float32)
|
|
142
|
+
dy_f32 = dy.to(tl.float32)
|
|
143
|
+
mean_f32 = mean.to(tl.float32)
|
|
144
|
+
rstd_f32 = rstd.to(tl.float32)
|
|
145
|
+
|
|
146
|
+
# Compute backward pass for this row
|
|
147
|
+
x_hat = (x_f32 - mean_f32) * rstd_f32
|
|
148
|
+
wdy = w_f32 * dy_f32
|
|
149
|
+
c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
|
|
150
|
+
c2 = tl.sum(wdy, axis=0) / n_cols
|
|
151
|
+
dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
|
|
152
|
+
|
|
153
|
+
# Store input gradient
|
|
154
|
+
tl.store(row_DX_ptr + cols, dx, mask=mask)
|
|
155
|
+
|
|
156
|
+
# Accumulate weight and bias gradients for this thread block's assigned rows
|
|
157
|
+
dw = dy_f32 * x_hat
|
|
158
|
+
db = dy_f32
|
|
159
|
+
dW_row += dw
|
|
160
|
+
db_row += db
|
|
161
|
+
|
|
162
|
+
row_X_ptr += stride_x
|
|
163
|
+
row_DX_ptr += stride_dx
|
|
164
|
+
row_DY_ptr += stride_dy
|
|
165
|
+
row_Mean_ptr += stride_mean
|
|
166
|
+
row_RSTD_ptr += stride_rstd
|
|
167
|
+
|
|
168
|
+
tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
|
|
169
|
+
tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)
|
|
150
170
|
|
|
151
171
|
|
|
152
172
|
def layer_norm_forward(X, W, B, eps):
|
|
@@ -228,31 +248,25 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
|
|
228
248
|
dY = dY.view(-1, dim)
|
|
229
249
|
n_rows, n_cols = dY.shape
|
|
230
250
|
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
251
|
+
sm_count = 1
|
|
252
|
+
if X.device.type == "cuda":
|
|
253
|
+
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
|
254
|
+
elif X.device.type == "xpu":
|
|
255
|
+
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
|
|
256
|
+
|
|
257
|
+
# fp32 for numerical stability especially.
|
|
258
|
+
_DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
|
|
259
|
+
_DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
|
|
237
260
|
|
|
238
261
|
# Calculate optimal block size and warp configuration
|
|
239
262
|
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
|
240
263
|
if n_cols > BLOCK_SIZE:
|
|
241
264
|
raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
|
|
265
|
+
rows_per_program = math.ceil(n_rows / sm_count)
|
|
266
|
+
grid = (sm_count,)
|
|
242
267
|
|
|
243
|
-
#
|
|
244
|
-
|
|
245
|
-
tl.float32
|
|
246
|
-
if X.dtype == torch.float32
|
|
247
|
-
else tl.bfloat16
|
|
248
|
-
if X.dtype == torch.bfloat16
|
|
249
|
-
else tl.float16
|
|
250
|
-
if X.dtype == torch.float16
|
|
251
|
-
else tl.float32 # fallback
|
|
252
|
-
)
|
|
253
|
-
|
|
254
|
-
# Use float32 for atomic operations if bfloat16 is not supported
|
|
255
|
-
atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype
|
|
268
|
+
# Allocate gradient tensors
|
|
269
|
+
DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
|
256
270
|
|
|
257
271
|
kernel_args = {"num_warps": num_warps}
|
|
258
272
|
# XPU-specific optimization
|
|
@@ -260,28 +274,33 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
|
|
260
274
|
kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
|
|
261
275
|
|
|
262
276
|
# Launch kernel with one thread block per row for optimal performance
|
|
263
|
-
grid = (n_rows,)
|
|
264
277
|
_layer_norm_backward_kernel[grid](
|
|
265
278
|
X,
|
|
279
|
+
X.stride(0),
|
|
266
280
|
W,
|
|
267
281
|
Mean,
|
|
282
|
+
Mean.stride(0),
|
|
268
283
|
RSTD,
|
|
284
|
+
RSTD.stride(0),
|
|
269
285
|
DX,
|
|
270
|
-
DW,
|
|
271
|
-
DB,
|
|
272
|
-
dY,
|
|
273
|
-
X.stride(0),
|
|
274
286
|
DX.stride(0),
|
|
287
|
+
_DW,
|
|
288
|
+
_DW.stride(0),
|
|
289
|
+
_DB,
|
|
290
|
+
_DB.stride(0),
|
|
291
|
+
dY,
|
|
275
292
|
dY.stride(0),
|
|
293
|
+
n_rows,
|
|
276
294
|
n_cols,
|
|
295
|
+
rows_per_program=rows_per_program,
|
|
277
296
|
BLOCK_SIZE=BLOCK_SIZE,
|
|
278
|
-
dtype=triton_dtype,
|
|
279
|
-
atomic_dtype=atomic_dtype,
|
|
280
297
|
**kernel_args,
|
|
281
298
|
)
|
|
282
299
|
|
|
283
300
|
DX = DX.view(*shape)
|
|
284
|
-
|
|
301
|
+
DW = _DW.sum(dim=0).to(W.dtype)
|
|
302
|
+
DB = _DB.sum(dim=0).to(B.dtype)
|
|
303
|
+
return DX, DW, DB
|
|
285
304
|
|
|
286
305
|
|
|
287
306
|
class LigerLayerNormFunction(torch.autograd.Function):
|
|
@@ -28,7 +28,7 @@ liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2wogg
|
|
|
28
28
|
liger_kernel/ops/grpo_loss.py,sha256=2SyOujtF9I3xiNo4wFf4s6MeiDotE_qeYfRWgj_bOBE,9573
|
|
29
29
|
liger_kernel/ops/jsd.py,sha256=onHp5T3MbvJaVz5Vup7Ww6EQp_HTaZeayTjJk6FgQMY,7042
|
|
30
30
|
liger_kernel/ops/kl_div.py,sha256=ZjGdDLKWksHT9dZ0xF_TDgAkj5cuMTwwT5tr9E-_24o,8734
|
|
31
|
-
liger_kernel/ops/layer_norm.py,sha256=
|
|
31
|
+
liger_kernel/ops/layer_norm.py,sha256=OMaex1MDsM9kaFs0-q5Pnx3DrMVjongQoZ5-iFIOy00,10523
|
|
32
32
|
liger_kernel/ops/llama4_rope.py,sha256=-aqdZzllklTN8b9--e-TsWY_ntGCN8-tyseT4x0bd8s,8223
|
|
33
33
|
liger_kernel/ops/multi_token_attention.py,sha256=Oz_RXDp-OSS_R_HuGmaETHdAJ7Toda_70OfE7TXMUlY,7645
|
|
34
34
|
liger_kernel/ops/poly_norm.py,sha256=MLgI8Ea93fugKibHCUauQ2ASYVXCvpPZe5v3kQZU6po,11152
|
|
@@ -110,9 +110,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
|
110
110
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
|
|
111
111
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
|
112
112
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
|
113
|
-
liger_kernel_nightly-0.6.3.
|
|
114
|
-
liger_kernel_nightly-0.6.3.
|
|
115
|
-
liger_kernel_nightly-0.6.3.
|
|
116
|
-
liger_kernel_nightly-0.6.3.
|
|
117
|
-
liger_kernel_nightly-0.6.3.
|
|
118
|
-
liger_kernel_nightly-0.6.3.
|
|
113
|
+
liger_kernel_nightly-0.6.3.dev20251121213521.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
|
114
|
+
liger_kernel_nightly-0.6.3.dev20251121213521.dist-info/METADATA,sha256=__p46-m1Fnwjp4mS78P8H7l3vFCXzyqw-MPEgfqbZZA,25238
|
|
115
|
+
liger_kernel_nightly-0.6.3.dev20251121213521.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
|
116
|
+
liger_kernel_nightly-0.6.3.dev20251121213521.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
117
|
+
liger_kernel_nightly-0.6.3.dev20251121213521.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
|
118
|
+
liger_kernel_nightly-0.6.3.dev20251121213521.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|